diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD index ac3e47d7..cca2b5b6 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD @@ -43,8 +43,11 @@ py_library( srcs = ["gradient_clipping_utils.py"], srcs_version = "PY3", deps = [ + ":common_manip_utils", ":layer_registry", ":type_aliases", + "//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry", + "//tensorflow_privacy/privacy/sparsity_preserving_noise:type_aliases", ], ) @@ -54,10 +57,7 @@ py_test( python_version = "PY3", shard_count = 8, srcs_version = "PY3", - deps = [ - ":gradient_clipping_utils", - ":type_aliases", - ], + deps = [":gradient_clipping_utils"], ) py_library( @@ -83,6 +83,12 @@ py_library( ], ) +py_library( + name = "noise_utils", + srcs = ["noise_utils.py"], + deps = ["//tensorflow_privacy/privacy/sparsity_preserving_noise:sparse_noise_utils"], +) + py_test( name = "clip_grads_test", srcs = ["clip_grads_test.py"], @@ -92,7 +98,14 @@ py_test( deps = [ ":clip_grads", ":common_test_utils", + ":gradient_clipping_utils", ":layer_registry", ":type_aliases", ], ) + +py_test( + name = "noise_utils_test", + srcs = ["noise_utils_test.py"], + deps = [":noise_utils"], +) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py index a3168499..70edd7d2 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -22,7 +22,7 @@ """ import collections -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import Optional import tensorflow as tf @@ -32,110 +32,81 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases -def get_registry_generator_fn( - tape: tf.GradientTape, - layer_registry: lr.LayerRegistry, - num_microbatches: Optional[type_aliases.BatchSize] = None, +def _compute_gradient_norms_internal( + registry_fn_outputs_list: Sequence[ + gradient_clipping_utils.RegistryGeneratorFunctionOutput + ], + layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]], + trainable_vars: Optional[Sequence[tf.Variable]] = None, ): - """Creates the generator function for `compute_gradient_norms()`.""" - if layer_registry is None: - # Needed for backwards compatibility. - registry_generator_fn = None - else: - - def registry_generator_fn(layer_instance, args, kwargs): - if layer_instance.trainable_variables: - # Only trainable variables factor into the gradient. - if not layer_registry.is_elem(layer_instance): - raise NotImplementedError( - 'Layer %s is not in the registry of known layers that can ' - 'be used for efficient gradient clipping.' - % layer_instance.__class__.__name__ - ) - registry_fn = layer_registry.lookup(layer_instance) - (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn( - layer_instance, args, kwargs, tape, num_microbatches - ) - return layer_outputs, ( - str(id(layer_instance)), - layer_vars, - layer_sqr_norm_fn, - layer_instance.trainable_weights, - ) - else: - # Non-trainable layer. - return layer_instance(*args, **kwargs), None + """Computes the per-example loss gradient norms for given data. - return registry_generator_fn + Args: + registry_fn_outputs_list: A `list` of RegistryGeneratorFunctionOutput + containing information required to compute the gradient norms and + contribution counts. 
Output from + `gradient_clipping_utils.model_forward_backward_pass()`. + layer_grad_vars: A mapping of layer id to a list of gradients for each + trainablev ariable in the layer. Output from + `gradient_clipping_utils.model_forward_backward_pass()`. + trainable_vars: The list of variables included in computing the gradient + norm. When a layer has multiple variables, we include all the variables if + any of the variables is in the list. If `trainable_vars` is None, all the + variables are included. + Returns: + A scalar vector, whose i-th entry is the norm of the gradient of the i-th + weighted example loss (when num_microbatches is None) or the norm of the + gradient of the i-th microbatch loss (define as a mean over the microbatch). + Note that when the loss is weighted (`weight_batch` is not None), weights + are applied prior to clipping. -def _infer_per_example_loss_fn(model: tf.keras.Model): - """Infer the per-example loss from model config.""" + Raises: + ValueError: If `layer_grad_vars` is empty. + ValueError: If the number of gradients for a layer is not equal to the + number of squared norm functions for that layer. + """ + if trainable_vars is not None: + # Create a set using `ref()` for fast set membership check. tf.Variable + # itself is not hashable. + trainable_vars = set([v.ref() for v in trainable_vars]) - def _convert(loss_fn): - loss_config = loss_fn.get_config() - loss_config['reduction'] = tf.keras.losses.Reduction.NONE - return loss_fn.from_config(loss_config) + layer_sqr_norm_fns = collections.defaultdict(list) + # The case of shared weights: + # If a layer is called k times, it will appear k times in filtered_outputs, + # with the same id, but potentially with different v and f. The code below + # groups filtered_outputs by layer_id, so we can correctly compute gradient + # norms. The gradient norm of a layer that occurs k times is computed as + # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th + # occurrence. This is an over-estimate of the actual norm. For more details, + # see the explanation in go/dp-sgd-shared-weights. + for registry_fn_output in registry_fn_outputs_list: + if trainable_vars is None or any( + w.ref() in trainable_vars + for w in registry_fn_output.layer_trainable_weights + ): + layer_sqr_norm_fns[registry_fn_output.layer_id].append( + registry_fn_output.layer_sqr_norm_fn + ) - model_loss = model.loss - if isinstance(model_loss, tf.keras.losses.Loss): - return _convert(model_loss) - elif isinstance(model_loss, dict): - # Note that we cannot call the public method `.get_compile_config()` because - # it calls a numpy function, which is not supported inside a `tf.function` - # wrapped function. - compile_config = model._compile_config.config # pylint: disable=protected-access - if compile_config is None: - raise ValueError('Model must be compiled for loss function conversion') - # Does a weighted mean of the configured losses. Note that we cannot build - # from the config of the compiled loss because (i) it builds a - # `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s - # during its construction, (ii) non-unique `tf.Variables` cannot be used - # inside a `tf.function`, which is usually where this function is used. 
- if 'loss_weights' not in compile_config: + if not layer_grad_vars: + raise ValueError('The gradient list cannot be empty.') + sqr_norm_list = [] + for layer_id in layer_sqr_norm_fns.keys(): + fns = layer_sqr_norm_fns[layer_id] + grads = layer_grad_vars[layer_id] + # Number of duplicates for this layer in `filtered_outputs`. + num_passes = len(fns) + if len(fns) != len(grads): raise ValueError( - 'Models with multiple loss must have corresponding loss weights for' - ' loss function conversion' + 'There must be as many gradients as squared norm functions.' ) - weights = compile_config['loss_weights'] - per_example_losses = {k: _convert(v) for k, v in model_loss.items()} - num_losses = len(weights) - - def _per_example_loss_fn(y_true, y_pred, sample_weight=None): - loss_values = [] - if model_loss.keys() - y_pred.keys(): - raise ValueError( - 'y_pred must contain the same keys and the model losses, but ' - 'got %s and %s' % (y_pred.keys(), model_loss.keys()) - ) - if model_loss.keys() - y_true.keys(): - raise ValueError( - 'y_true must contain the same keys and the model losses, but ' - 'got %s and %s' % (y_true.keys(), model_loss.keys()) - ) - if sample_weight is not None: - if model_loss.keys() - sample_weight.keys(): - raise ValueError( - 'sample_weight must contain the same keys and the model losses,' - ' but got %s and %s' % (y_true.keys(), model_loss.keys()) - ) - for k in y_true.keys(): - sgl_sample_weight = None if sample_weight is None else sample_weight[k] - sgl_value = ( - weights[k] - * per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight) - / num_losses - ) - loss_values.append(tf.reshape(sgl_value, shape=[-1])) - return tf.math.add_n(loss_values) - - return _per_example_loss_fn - else: - raise ValueError( - 'Unsupported type for loss function conversion: {}'.format( - type(model_loss) - ) - ) + # See go/dp-sgd-shared-weights for more details. + for fn, grad in zip(fns, grads): + sqr_norm_list.append(num_passes * fn(grad)) + sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1) + gradient_norms = tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)) + return gradient_norms def compute_gradient_norms( @@ -147,7 +118,7 @@ def compute_gradient_norms( per_example_loss_fn: Optional[type_aliases.LossFn] = None, num_microbatches: Optional[type_aliases.BatchSize] = None, trainable_vars: Optional[Sequence[tf.Variable]] = None, -): +) -> tf.Tensor: """Computes the per-example loss gradient norms for given data. Applies a variant of the approach given in @@ -190,86 +161,29 @@ def compute_gradient_norms( are applied prior to clipping. """ tape = tf.GradientTape(persistent=True, watch_accessed_variables=False) - registry_generator_fn = get_registry_generator_fn( - tape, layer_registry, num_microbatches + registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn( + tape=tape, + layer_registry=layer_registry, + sparse_noise_layer_registry=None, + num_microbatches=num_microbatches, ) - # First loop computes the model outputs, summed loss, and generator outputs. - with tape: - model_outputs, generator_outputs_list = ( - gradient_clipping_utils.model_forward_pass( - input_model, x_batch, generator_fn=registry_generator_fn - ) - ) - - # Ignore the original loss function's reduction to get per-example loss. 
- if per_example_loss_fn is None: - per_example_loss_fn = _infer_per_example_loss_fn(input_model) - - losses = per_example_loss_fn(y_batch, model_outputs, weight_batch) - if losses.shape is None: - raise NotImplementedError( - "The unreduced (or per-example) loss's shape cannot be `None`" - ) - if len(losses.shape) != 1: - raise NotImplementedError( - 'The unreduced (or per-example) loss needs to have a shape of length ' - 'one, but received an unreduced loss of shape length %s' - % len(losses.shape) - ) - if num_microbatches is not None: - losses = tf.reduce_mean( - common_manip_utils.maybe_add_microbatch_axis( - losses, num_microbatches - ), - axis=1, + layer_grad_vars, generator_outputs_list = ( + gradient_clipping_utils.model_forward_backward_pass( + tape=tape, + input_model=input_model, + x_batch=x_batch, + y_batch=y_batch, + registry_generator_fn=registry_generator_fn, + weight_batch=weight_batch, + per_example_loss_fn=per_example_loss_fn, + num_microbatches=num_microbatches, ) - summed_loss = tf.reduce_sum(losses) - # Unwrap the generator outputs so that the next loop avoids duplicating - # backprop ops. - filtered_outputs = [t for t in generator_outputs_list if t is not None] - if trainable_vars is not None: - # Create a set using `ref()` for fast set membership check. tf.Variable - # itself is not hashable. - trainable_vars = set([v.ref() for v in trainable_vars]) - layer_vars = collections.defaultdict(list) - layer_sqr_norm_fns = collections.defaultdict(list) - # The case of shared weights: - # If a layer is called k times, it will appear k times in filtered_outputs, - # with the same id, but potentially with different v and f. The code below - # groups filtered_outputs by layer_id, so we can correctly compute gradient - # norms. The gradient norm of a layer that occurs k times is computed as - # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th - # occurrence. This is an over-estimate of the actual norm. For more details, - # see the explanation in go/dp-sgd-shared-weights. - for layer_id, v, f, weights_list in filtered_outputs: - if trainable_vars is None or any( - w.ref() in trainable_vars for w in weights_list - ): - layer_vars[layer_id].append(v) - layer_sqr_norm_fns[layer_id].append(f) - # Second loop evaluates the squared L2 norm functions and appends the results. - layer_grad_vars = tape.gradient( - summed_loss, - layer_vars, - unconnected_gradients=tf.UnconnectedGradients.ZERO, ) - if not layer_grad_vars: - raise ValueError('The gradient list cannot be empty.') - sqr_norm_list = [] - for layer_id in layer_sqr_norm_fns.keys(): - fns = layer_sqr_norm_fns[layer_id] - grads = layer_grad_vars[layer_id] - # Number of duplicates for this layer in `filtered_outputs`. - num_passes = len(fns) - if len(fns) != len(grads): - raise ValueError( - 'There must be as many gradients as squared norm functions.' - ) - # See go/dp-sgd-shared-weights for more details. 
- for fn, grad in zip(fns, grads): - sqr_norm_list.append(num_passes * fn(grad)) - sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1) - return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)) + return _compute_gradient_norms_internal( + registry_fn_outputs_list=generator_outputs_list, + layer_grad_vars=layer_grad_vars, + trainable_vars=trainable_vars, + ) def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor): @@ -299,14 +213,17 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor): def compute_clipped_gradients_and_outputs( input_model: tf.keras.Model, + registry_fn_outputs_list: Sequence[ + gradient_clipping_utils.RegistryGeneratorFunctionOutput + ], + layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]], l2_norm_clip: float, - layer_registry: lr.LayerRegistry, x_batch: type_aliases.InputTensors, y_batch: type_aliases.OutputTensors, weight_batch: Optional[tf.Tensor] = None, num_microbatches: Optional[type_aliases.BatchSize] = None, clipping_loss: Optional[type_aliases.LossFn] = None, -) -> tuple[Sequence[tf.Tensor], tf.Tensor, tf.Tensor]: +) -> tuple[Sequence[type_aliases.Tensor], tf.Tensor, tf.Tensor]: """Computes the per-example clipped loss gradient and other useful outputs. Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main @@ -319,15 +236,16 @@ def compute_clipped_gradients_and_outputs( Args: input_model: The `tf.keras.Model` from which to obtain the layers from. + registry_fn_outputs_list: A `list` of RegistryGeneratorFunctionOutput + containing information required to compute the gradient norms and + contribution counts. Output from + `gradient_clipping_utils.model_forward_backward_pass()`. + layer_grad_vars: A mapping of layer id to a list of gradients for each + trainablev ariable in the layer. Output from + `gradient_clipping_utils.model_forward_backward_pass()`. l2_norm_clip: A `float` indicating the norm to which per-example gradients will be clipped. That is, all gradients of the per-example loss functions will have norm at most `l2_norm_clip`. - layer_registry: A `dict` of layers that support "fast" gradient norm - computations. The key is the class of the layer and the value is a - function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where - `output` is the pre-activator tensor, `sqr_grad_norms` is related to the - squared norms of a layer's pre-activation tensor, and `vars` are relevant - trainable weights (see `layer_registry_factories.py` for examples). x_batch: An `InputTensor` representing a batch of inputs to the model. The first axes of each tensor must be the batch dimension. y_batch: An `OutputTensor` representing a batch of output labels. 
The first @@ -362,13 +280,9 @@ def compute_clipped_gradients_and_outputs( ) if clipping_loss is None: clipping_loss = input_model.compiled_loss - gradient_norms = compute_gradient_norms( - input_model, - layer_registry, - x_batch, - y_batch, - weight_batch, - num_microbatches=num_microbatches, + gradient_norms = _compute_gradient_norms_internal( + registry_fn_outputs_list=registry_fn_outputs_list, + layer_grad_vars=layer_grad_vars, trainable_vars=input_model.trainable_variables, ) clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py index a028affe..6ec47941 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py @@ -19,6 +19,7 @@ import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils +from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases @@ -122,6 +123,30 @@ def test_gradient_norms_on_various_models( self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) +def _run_model_forward_backward_pass( + model: tf.keras.Model, + x_batch: type_aliases.InputTensors, + y_batch: type_aliases.OutputTensors, +): + tape = tf.GradientTape(persistent=True, watch_accessed_variables=False) + registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn( + tape=tape, + layer_registry=layer_registry.make_default_layer_registry(), + sparse_noise_layer_registry=None, + num_microbatches=None, + ) + layer_grad_vars, registry_fn_outputs_list = ( + gradient_clipping_utils.model_forward_backward_pass( + tape=tape, + input_model=model, + x_batch=x_batch, + y_batch=y_batch, + registry_generator_fn=registry_generator_fn, + ) + ) + return layer_grad_vars, registry_fn_outputs_list + + class ComputeClippedGradsAndOutputsTest( tf.test.TestCase, parameterized.TestCase ): @@ -153,13 +178,17 @@ def test_clipped_gradients_on_different_losses( y_batch = tf.reshape( 1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1] ) + layer_grad_vars, registry_fn_outputs_list = ( + _run_model_forward_backward_pass(self._model, x_batch, y_batch) + ) # Stop early for efficiency. 
if reduction == 'none': with self.assertRaises(NotImplementedError): clip_grads.compute_clipped_gradients_and_outputs( self._model, + registry_fn_outputs_list, + layer_grad_vars, l2_norm_clip, - layer_registry.make_default_layer_registry(), x_batch, y_batch, ) @@ -169,10 +198,12 @@ def test_clipped_gradients_on_different_losses( y_pred = self._model(x_batch) loss_value = loss_fn(y_pred, y_batch) true_grads = tape.gradient(loss_value, self._model.trainable_variables) + clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs( self._model, + registry_fn_outputs_list, + layer_grad_vars, l2_norm_clip, - layer_registry.make_default_layer_registry(), x_batch, y_batch, ) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py index 74163187..989a69c2 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py @@ -13,13 +13,28 @@ # limitations under the License. """Utility functions that help in the computation of per-example gradient norms.""" -from collections.abc import Sequence, Set -from typing import Any, Literal, Optional +import collections +from collections.abc import Callable, Sequence, Set +import dataclasses +from typing import Any, Optional, Tuple -from absl import logging import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases +from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr +from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases as sn_type_aliases + + +@dataclasses.dataclass(frozen=True) +class RegistryGeneratorFunctionOutput: + layer_id: str + layer_vars: Optional[Sequence[tf.Variable]] + layer_sqr_norm_fn: Optional[type_aliases.SquareNormFunction] + varname_to_count_contribution_fn: Optional[ + dict[str, sn_type_aliases.ContributionCountHistogramFn] + ] + layer_trainable_weights: Optional[Sequence[tf.Variable]] def has_internal_compute_graph(input_object: Any): @@ -33,6 +48,222 @@ def has_internal_compute_graph(input_object: Any): ) +def get_registry_generator_fn( + tape: tf.GradientTape, + layer_registry: lr.LayerRegistry, + sparse_noise_layer_registry: snlr.LayerRegistry, + num_microbatches: Optional[type_aliases.BatchSize] = None, +) -> Optional[Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]]: + """Creates the generator function for `model_forward_backward_pass()`. + + Args: + tape: The `tf.GradientTape` to use for the gradient computation. + layer_registry: A `dict` of layers that support "fast" gradient norm + computations. The key is the class of the layer and the value is a + function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where + `output` is the pre-activator tensor, `sqr_grad_norms` is related to the + squared norms of a layer's pre-activation tensor, and `vars` are relevant + trainable + sparse_noise_layer_registry: A `LayerRegistry` instance containing functions + that help compute contribution counts for sparse noise. See + `tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry` for + more details. + num_microbatches: An optional number or scalar `tf.Tensor` for the number of + microbatches. 
If not None, indicates that the loss is grouped into + num_microbatches (in this case, the batch dimension needs to be a multiple + of num_microbatches). + + Returns: + A function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where + `output` is the pre-activator tensor, `sqr_grad_norms` is related to the + squared norms of a layer's pre-activation tensor, and `vars` are relevant + trainable variables. + """ + if layer_registry is None: + # Needed for backwards compatibility. + registry_generator_fn = None + else: + + def registry_generator_fn(layer_instance, args, kwargs): + if layer_instance.trainable_variables: + # Only trainable variables factor into the gradient. + if not layer_registry.is_elem(layer_instance): + raise NotImplementedError( + 'Layer %s is not in the registry of known layers that can ' + 'be used for efficient gradient clipping.' + % layer_instance.__class__.__name__ + ) + varname_to_count_contribution_fn = None + if sparse_noise_layer_registry and sparse_noise_layer_registry.is_elem( + layer_instance + ): + count_contribution_registry_fn = sparse_noise_layer_registry.lookup( + layer_instance + ) + varname_to_count_contribution_fn = count_contribution_registry_fn( + layer_instance, args, kwargs, num_microbatches + ) + registry_fn = layer_registry.lookup(layer_instance) + (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn( + layer_instance, args, kwargs, tape, num_microbatches + ) + return layer_outputs, RegistryGeneratorFunctionOutput( + layer_id=str(id(layer_instance)), + layer_vars=layer_vars, + layer_sqr_norm_fn=layer_sqr_norm_fn, + varname_to_count_contribution_fn=varname_to_count_contribution_fn, + layer_trainable_weights=layer_instance.trainable_weights, + ) + else: + # Non-trainable layer. + return layer_instance(*args, **kwargs), None + + return registry_generator_fn + + +def _infer_per_example_loss_fn(model: tf.keras.Model): + """Infer the per-example loss from model config.""" + + def _convert(loss_fn): + loss_config = loss_fn.get_config() + loss_config['reduction'] = tf.keras.losses.Reduction.NONE + return loss_fn.from_config(loss_config) + + model_loss = model.loss + if isinstance(model_loss, tf.keras.losses.Loss): + return _convert(model_loss) + elif isinstance(model_loss, dict): + # Note that we cannot call the public method `.get_compile_config()` because + # it calls a numpy function, which is not supported inside a `tf.function` + # wrapped function. + compile_config = model._compile_config.config # pylint: disable=protected-access + if compile_config is None: + raise ValueError('Model must be compiled for loss function conversion') + # Does a weighted mean of the configured losses. Note that we cannot build + # from the config of the compiled loss because (i) it builds a + # `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s + # during its construction, (ii) non-unique `tf.Variables` cannot be used + # inside a `tf.function`, which is usually where this function is used. 
+ if 'loss_weights' not in compile_config: + raise ValueError( + 'Models with multiple loss must have corresponding loss weights for' + ' loss function conversion' + ) + weights = compile_config['loss_weights'] + per_example_losses = {k: _convert(v) for k, v in model_loss.items()} + num_losses = len(weights) + + def _per_example_loss_fn(y_true, y_pred, sample_weight=None): + loss_values = [] + if model_loss.keys() - y_pred.keys(): + raise ValueError( + 'y_pred must contain the same keys and the model losses, but ' + 'got %s and %s' % (y_pred.keys(), model_loss.keys()) + ) + if model_loss.keys() - y_true.keys(): + raise ValueError( + 'y_true must contain the same keys and the model losses, but ' + 'got %s and %s' % (y_true.keys(), model_loss.keys()) + ) + if sample_weight is not None: + if model_loss.keys() - sample_weight.keys(): + raise ValueError( + 'sample_weight must contain the same keys and the model losses,' + ' but got %s and %s' % (y_true.keys(), model_loss.keys()) + ) + for k in y_true.keys(): + sgl_sample_weight = None if sample_weight is None else sample_weight[k] + sgl_value = ( + weights[k] + * per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight) + / num_losses + ) + loss_values.append(tf.reshape(sgl_value, shape=[-1])) + return tf.math.add_n(loss_values) + + return _per_example_loss_fn + else: + raise ValueError( + 'Unsupported type for loss function conversion: {}'.format( + type(model_loss) + ) + ) + + +def model_forward_backward_pass( + tape: tf.GradientTape, + input_model: tf.keras.Model, + x_batch: type_aliases.InputTensors, + y_batch: type_aliases.OutputTensors, + registry_generator_fn: Optional[ + Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]] + ], + weight_batch: Optional[tf.Tensor] = None, + per_example_loss_fn: Optional[type_aliases.LossFn] = None, + num_microbatches: Optional[type_aliases.BatchSize] = None, + trainable_vars: Optional[Sequence[tf.Variable]] = None, +) -> tuple[ + dict[str, list[type_aliases.Tensor]], list[RegistryGeneratorFunctionOutput] +]: + """Does a forward and backward pass of a model and returns useful intermediates.""" + # First loop computes the model outputs, summed loss, and generator outputs. + with tape: + model_outputs, generator_outputs_list = model_forward_pass( + input_model, x_batch, generator_fn=registry_generator_fn + ) + + # Ignore the original loss function's reduction to get per-example loss. + if per_example_loss_fn is None: + per_example_loss_fn = _infer_per_example_loss_fn(input_model) + + losses = per_example_loss_fn(y_batch, model_outputs, weight_batch) + if losses.shape is None: + raise NotImplementedError( + "The unreduced (or per-example) loss's shape cannot be `None`" + ) + if len(losses.shape) != 1: + raise NotImplementedError( + 'The unreduced (or per-example) loss needs to have a shape of length ' + 'one, but received an unreduced loss of shape length %s' + % len(losses.shape) + ) + if num_microbatches is not None: + losses = tf.reduce_mean( + common_manip_utils.maybe_add_microbatch_axis( + losses, num_microbatches + ), + axis=1, + ) + summed_loss = tf.reduce_sum(losses) + # Unwrap the generator outputs so that the next loop avoids duplicating + # backprop ops. + filtered_outputs = [t for t in generator_outputs_list if t is not None] + + if trainable_vars is not None: + # Create a set using `ref()` for fast set membership check. tf.Variable + # itself is not hashable. 
+ trainable_vars = set([v.ref() for v in trainable_vars]) + layer_vars = collections.defaultdict(list) + for registry_fn_output in filtered_outputs: + if trainable_vars is None or any( + w.ref() in trainable_vars + for w in registry_fn_output.layer_trainable_weights + ): + layer_vars[registry_fn_output.layer_id].append( + registry_fn_output.layer_vars + ) + + layer_grad_vars = tape.gradient( + summed_loss, + layer_vars, + unconnected_gradients=tf.UnconnectedGradients.ZERO, + ) + if not layer_grad_vars: + raise ValueError('The gradient list cannot be empty.') + + return layer_grad_vars, filtered_outputs + + def model_forward_pass( input_model: tf.keras.Model, inputs: type_aliases.PackedTensors, @@ -144,101 +375,6 @@ def all_trainable_layers_are_registered( return True -def _infer_loss_reduction_type(model: tf.keras.Model): - """Infers what type of loss reduction is being performed.""" - model_loss = model.loss - if isinstance(model_loss, tf.keras.losses.Loss): - return model_loss.reduction - elif isinstance(model.loss, dict): - reductions = set() - compiled_loss = model.compiled_loss - if compiled_loss is None: - raise ValueError('Model must be compiled for adding noise') - new_config_list = compiled_loss.get_config()['losses'] - for loss_config in new_config_list: - reductions.add(loss_config['config']['reduction']) - if len(reductions) > 1: - raise ValueError( - 'Reductions in models with multiple losses must all be the same' - ) - return reductions.pop() - else: - raise ValueError( - 'Unsupported type for adding noise: {}'.format(type(model_loss)) - ) - - -def add_aggregate_noise( - clipped_grads: list[tf.Tensor], - batch_size: tf.Tensor, - l2_norm_clip: float, - noise_multiplier: float, - loss_reduction: Optional[Literal['mean', 'sum']] = None, - loss_model: Optional[tf.keras.Model] = None, -) -> Sequence[tf.Tensor]: - """Adds noise to a collection of clipped gradients. - - The magnitude of the noise depends on the aggregation strategy of the - input model's loss function. - - Args: - clipped_grads: A list of `tf.Tensor`s representing the clipped gradients. - batch_size: The batch size. Used for normalizing the noise when - `loss_reduction` is 'sum'. - l2_norm_clip: Clipping norm (max L2 norm of each gradient). - noise_multiplier: Ratio of the standard deviation to the clipping norm. - loss_reduction: An string description of how the loss is reduced over - examples. Currently supports 'mean' and 'sum'. If `None`, then the - aggregation type must be inferred from `input_model.loss`. - loss_model: An optional `tf.keras.Model` used to infer the loss reduction - strategy from if `loss_reduction` is `None`. - - Returns: - A list of tensors containing the clipped gradients, but with the right - amount of Gaussian noise added to them (depending on the reduction - strategy of the loss function). - - Raises: - ValueError: If both `loss_model` and `loss_reduction` are `None` or if - they are both not `None`. - """ - if loss_reduction is None and loss_model is None: - raise ValueError( - 'Exactly one of `loss_reduction` and `loss_model` must be populated.' - ' Instead, both arguments were `None`.' - ) - if loss_reduction is not None and loss_model is not None: - raise ValueError( - 'Exactly one of `loss_reduction` and `loss_model` must be populated.' - ' Instead, both arguments were not `None`.' 
- ) - - if loss_reduction is None and loss_model is not None: - implicit_mean_reductions = [ - tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE, - tf.keras.losses.Reduction.AUTO, - ] - model_reduction = _infer_loss_reduction_type(loss_model) - loss_reduction = ( - 'mean' if model_reduction in implicit_mean_reductions else 'sum' - ) - if model_reduction == tf.keras.losses.Reduction.AUTO: - logging.info( - 'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.' - ) - - scale = l2_norm_clip - if loss_reduction == 'mean': - scale /= tf.cast(batch_size, tf.float32) - - def add_noise(g): - return g + tf.random.normal( - tf.shape(g), mean=0.0, stddev=noise_multiplier * scale - ) - - return tf.nest.map_structure(add_noise, clipped_grads) - - def generate_model_outputs_using_core_keras_layers( input_model: tf.keras.Model, custom_layer_set: Optional[Set[type]] = None, # pylint: disable=g-bare-generic diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py index f7cf3b0f..7069273d 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py @@ -15,7 +15,6 @@ from typing import Any from absl.testing import parameterized -import numpy as np import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils @@ -135,60 +134,6 @@ def test_outputs_are_consistent( self.assertAllClose(computed_outputs, true_outputs) -class AddAggregateNoise(tf.test.TestCase, parameterized.TestCase): - - @parameterized.product( - l2_norm_clip=[3.0, 5.0], - noise_multiplier=[2.0, 4.0], - batch_size=[1, 2, 10], - model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'], - noise_fn_reduction=[None, 'mean', 'sum'], - ) - def test_noise_is_computed_correctly( - self, - l2_norm_clip, - noise_multiplier, - batch_size, - model_fn_reduction, - noise_fn_reduction, - ): - # Skip invalid combinations. - if model_fn_reduction is None and noise_fn_reduction is None: - return - if model_fn_reduction is not None and noise_fn_reduction is not None: - return - # Make an simple model container for storing the loss. - if model_fn_reduction is not None: - linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)]) - linear_model.compile( - loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction) - ) - else: - linear_model = None - # The main computation is done on a deterministic dummy vector. - num_units = 100 - clipped_grads = [ - tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1) - ] - noised_grads = gradient_clipping_utils.add_aggregate_noise( - clipped_grads, - batch_size, - l2_norm_clip, - noise_multiplier, - noise_fn_reduction, - linear_model, - ) - # The only measure that varies is the standard deviation of the variation. 
- scale = ( - 1.0 - if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum' - else 1.0 / batch_size - ) - computed_std = np.std(noised_grads[0] - clipped_grads[0]) - expected_std = l2_norm_clip * noise_multiplier * scale - self.assertNear(computed_std, expected_std, 0.1 * expected_std) - - class GenerateOutputsUsingCoreKerasLayers( tf.test.TestCase, parameterized.TestCase ): diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py new file mode 100644 index 00000000..22fc8ea5 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py @@ -0,0 +1,146 @@ +# Copyright 2022, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utility functions that help in adding noise to gradients.""" + +from collections.abc import Sequence +from typing import Literal, Optional + +from absl import logging +import tensorflow as tf +from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils + + +def _infer_loss_reduction_type(model: tf.keras.Model): + """Infers what type of loss reduction is being performed.""" + model_loss = model.loss + if isinstance(model_loss, tf.keras.losses.Loss): + return model_loss.reduction + elif isinstance(model.loss, dict): + reductions = set() + compiled_loss = model.compiled_loss + if compiled_loss is None: + raise ValueError('Model must be compiled for adding noise') + new_config_list = compiled_loss.get_config()['losses'] + for loss_config in new_config_list: + reductions.add(loss_config['config']['reduction']) + if len(reductions) > 1: + raise ValueError( + 'Reductions in models with multiple losses must all be the same' + ) + return reductions.pop() + else: + raise ValueError( + 'Unsupported type for adding noise: {}'.format(type(model_loss)) + ) + + +def add_aggregate_noise( + clipped_grads: list[tf.Tensor | tf.IndexedSlices], + batch_size: tf.Tensor, + l2_norm_clip: float, + noise_multiplier: float, + loss_reduction: Optional[Literal['mean', 'sum']] = None, + loss_model: Optional[tf.keras.Model] = None, + use_sparse_noise: bool = False, + sparse_noise_multiplier: float = 0.0, + sparse_selection_threshold: int = 0, + sparse_contribution_counts: Optional[Sequence[tf.SparseTensor]] = None, +) -> Sequence[tf.Tensor | tf.IndexedSlices]: + """Adds noise to a collection of clipped gradients. + + The magnitude of the noise depends on the aggregation strategy of the + input model's loss function. + + Args: + clipped_grads: A list of `tf.Tensor`s or `tf.IndexedSlices`s representing + the clipped gradients. + batch_size: The batch size. Used for normalizing the noise when + `loss_reduction` is 'sum'. + l2_norm_clip: Clipping norm (max L2 norm of each gradient). + noise_multiplier: Ratio of the standard deviation to the clipping norm. + loss_reduction: An string description of how the loss is reduced over + examples. Currently supports 'mean' and 'sum'. 
If `None`, then the + aggregation type must be inferred from `input_model.loss`. + loss_model: An optional `tf.keras.Model` used to infer the loss reduction + strategy from if `loss_reduction` is `None`. + use_sparse_noise: Whether to use sparse noise. + sparse_noise_multiplier: The multiplier for the sparse noise. + sparse_selection_threshold: The threshold for the sparse noise. + sparse_contribution_counts: A list of `tf.Tensor`s representing the + contribution counts for each sparse gradient in clipped_grads. + + Returns: + A list of tensors containing the clipped gradients, but with the right + amount of Gaussian or sparse Gaussain noise added to them (depending on + the reduction strategy of the loss function). + + Raises: + ValueError: If both `loss_model` and `loss_reduction` are `None` or if + they are both not `None`. + """ + if loss_reduction is None and loss_model is None: + raise ValueError( + 'Exactly one of `loss_reduction` and `loss_model` must be populated.' + ' Instead, both arguments were `None`.' + ) + if loss_reduction is not None and loss_model is not None: + raise ValueError( + 'Exactly one of `loss_reduction` and `loss_model` must be populated.' + ' Instead, both arguments were not `None`.' + ) + + if loss_reduction is None and loss_model is not None: + implicit_mean_reductions = [ + tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE, + tf.keras.losses.Reduction.AUTO, + ] + model_reduction = _infer_loss_reduction_type(loss_model) + loss_reduction = ( + 'mean' if model_reduction in implicit_mean_reductions else 'sum' + ) + if model_reduction == tf.keras.losses.Reduction.AUTO: + logging.info( + 'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.' + ) + + if sparse_contribution_counts is None: + sparse_contribution_counts = tf.nest.map_structure( + lambda x: None, clipped_grads + ) + + scale = l2_norm_clip + if loss_reduction == 'mean': + scale /= tf.cast(batch_size, tf.float32) + + def add_noise(grad, contribution_counts): + if ( + use_sparse_noise + and isinstance(grad, tf.IndexedSlices) + and contribution_counts is not None + ): + return sparse_noise_utils.add_sparse_noise( + grad, + contribution_counts, + noise_multiplier, + sparse_noise_multiplier, + l2_norm_clip, + sparse_selection_threshold, + ) + return grad + tf.random.normal( + tf.shape(grad), mean=0.0, stddev=noise_multiplier * scale + ) + + return tf.nest.map_structure( + add_noise, clipped_grads, sparse_contribution_counts + ) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils_test.py new file mode 100644 index 00000000..746351e9 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils_test.py @@ -0,0 +1,161 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from absl.testing import parameterized +import numpy as np +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils + + +class NoiseUtilsTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.product( + l2_norm_clip=[3.0, 5.0], + noise_multiplier=[2.0, 4.0], + batch_size=[1, 2, 10], + model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'], + noise_fn_reduction=[None, 'mean', 'sum'], + ) + def test_noise_is_computed_correctly( + self, + l2_norm_clip, + noise_multiplier, + batch_size, + model_fn_reduction, + noise_fn_reduction, + ): + # Skip invalid combinations. + if model_fn_reduction is None and noise_fn_reduction is None: + return + if model_fn_reduction is not None and noise_fn_reduction is not None: + return + # Make an simple model container for storing the loss. + if model_fn_reduction is not None: + linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)]) + linear_model.compile( + loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction) + ) + else: + linear_model = None + # The main computation is done on a deterministic dummy vector. + num_units = 100 + clipped_grads = [ + tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1) + ] + noised_grads = noise_utils.add_aggregate_noise( + clipped_grads, + batch_size, + l2_norm_clip, + noise_multiplier, + noise_fn_reduction, + linear_model, + ) + # The only measure that varies is the standard deviation of the variation. + scale = ( + 1.0 + if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum' + else 1.0 / batch_size + ) + computed_std = np.std(noised_grads[0] - clipped_grads[0]) + expected_std = l2_norm_clip * noise_multiplier * scale + self.assertNear(computed_std, expected_std, 0.1 * expected_std) + + @parameterized.product( + l2_norm_clip=[3.0, 5.0], + noise_multiplier=[2.0, 4.0], + sparse_noise_multiplier=[1.0], + batch_size=[1, 2, 10], + model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'], + noise_fn_reduction=[None, 'mean', 'sum'], + ) + def test_sparse_noise_is_computed_correctly( + self, + l2_norm_clip, + noise_multiplier, + sparse_noise_multiplier, + batch_size, + model_fn_reduction, + noise_fn_reduction, + ): + # Skip invalid combinations. + if model_fn_reduction is None and noise_fn_reduction is None: + return + if model_fn_reduction is not None and noise_fn_reduction is not None: + return + # Make an simple model container for storing the loss. + if model_fn_reduction is not None: + linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)]) + linear_model.compile( + loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction) + ) + else: + linear_model = None + # The main computation is done on a deterministic dummy vector. 
+ num_units = 100 + dense_grad = tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1) + sparse_grad = tf.IndexedSlices( + values=tf.ones((3, 4)), + indices=tf.constant([0, 3, 5]), + dense_shape=tf.constant([8, 4]), + ) + sparse_grad_contribution_counts = tf.SparseTensor( + indices=[[0], [3], [5]], + values=[10.0, 10.0, 20.0], + dense_shape=[8], + ) + + sparse_noised_grad, dense_noised_grad = noise_utils.add_aggregate_noise( + clipped_grads=[dense_grad, sparse_grad], + batch_size=batch_size, + l2_norm_clip=l2_norm_clip, + noise_multiplier=noise_multiplier, + loss_model=linear_model, + use_sparse_noise=True, + sparse_noise_multiplier=sparse_noise_multiplier, + sparse_selection_threshold=8, + sparse_contribution_counts=[None, sparse_grad_contribution_counts], + ) + self.assertContainsSubset( + sparse_grad.indices.numpy().tolist(), + sparse_noised_grad.indices.numpy().tolist(), + ) + sparse_noised_grad_dense = tf.scatter_nd( + tf.reshape(sparse_noised_grad.indices, (-1, 1)), + sparse_noised_grad.values, + shape=(8, 4), + ).numpy() + sparse_noised_grad_valid_indices = sparse_noised_grad_dense[ + sparse_grad.indices.numpy() + ] + sparse_grad_values = sparse_grad.values.numpy() + self.assertTrue( + np.all( + np.not_equal(sparse_noised_grad_valid_indices, sparse_grad_values) + ) + ) + scale = ( + 1.0 + if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum' + else 1.0 / batch_size + ) + # The only measure that varies is the standard deviation of the variation. + expected_std = l2_norm_clip * noise_multiplier * scale + + sparse_computed_std = np.std( + sparse_noised_grad_valid_indices - sparse_grad_values + ) + self.assertNear(sparse_computed_std, expected_std, 0.1 * expected_std) + + dense_computed_std = np.std(dense_noised_grad - dense_grad) + self.assertNear(dense_computed_std, expected_std, 0.1 * expected_std) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py b/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py index 1e602a01..c4033a8c 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py @@ -19,11 +19,13 @@ # Tensorflow aliases. 
-PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Mapping[str, tf.Tensor]] +Tensor = Union[tf.Tensor, tf.IndexedSlices, tf.SparseTensor, tf.RaggedTensor] + +PackedTensors = Union[Tensor, Iterable[Tensor], Mapping[str, Tensor]] InputTensors = PackedTensors -OutputTensors = Union[tf.Tensor, Iterable[tf.Tensor]] +OutputTensors = Union[Tensor, Iterable[Tensor]] BatchSize = Union[int, tf.Tensor] diff --git a/tensorflow_privacy/privacy/keras_models/BUILD b/tensorflow_privacy/privacy/keras_models/BUILD index a4ae42c8..55334ab1 100644 --- a/tensorflow_privacy/privacy/keras_models/BUILD +++ b/tensorflow_privacy/privacy/keras_models/BUILD @@ -17,7 +17,7 @@ py_library( "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads", "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils", "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils", - "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry", + "//tensorflow_privacy/privacy/sparsity_preserving_noise:sparse_noise_utils", ], ) diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py index 4879b3bb..36fa6a1f 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py @@ -18,6 +18,8 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils +from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils + _PRIVATIZED_LOSS_NAME = 'privatized_loss' @@ -103,6 +105,10 @@ def __init__( num_microbatches=None, use_xla=True, layer_registry=None, + use_sparse_noise=False, + sparse_selection_ratio=None, + sparse_selection_threshold=None, + sparse_selection_layer_registry=None, *args, # pylint: disable=keyword-arg-before-vararg, g-doc-args **kwargs, ): @@ -117,6 +123,17 @@ def __init__( help compute gradient norms quickly. See `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for more details. + use_sparse_noise: If `True`, uses partition selection and sparse noise + for privatizing sparse gradients for layers in + `sparse_selection_layer_registry`. + sparse_selection_ratio: The ratio of how the noise is split between + partition selection and gradient noise. + sparse_selection_threshold: The threshold to use for private partition + selection. + sparse_selection_layer_registry: A `LayerRegistry` instance containing + functions that help compute contribution counts for sparse layers. See + `tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry` + for more details. *args: These will be passed on to the base class `__init__` method. **kwargs: These will be passed on to the base class `__init__` method. """ @@ -126,6 +143,11 @@ def __init__( self._layer_registry = layer_registry self._clipping_loss = None + self._use_sparse_noise = use_sparse_noise + self._sparse_selection_ratio = sparse_selection_ratio + self._sparse_selection_threshold = sparse_selection_threshold + self._sparse_selection_layer_registry = sparse_selection_layer_registry + # Given that `num_microbatches` was added as an argument after the fact, # this check helps detect unintended calls to the earlier API. 
# In particular, boolean values supplied to `use_xla` in the earlier API @@ -273,27 +295,84 @@ def train_step(self, data): # trick, and uses these norms to clip the per-example gradients. # NOTE: Reshaping of the input according to the effective number of # microbatches is done here. + tape = tf.GradientTape(persistent=True, watch_accessed_variables=False) + + registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn( + tape=tape, + layer_registry=self._layer_registry, + sparse_noise_layer_registry=self._sparse_selection_layer_registry, + num_microbatches=num_microbatches, + ) + layer_grad_vars, registry_fn_outputs_list = ( + gradient_clipping_utils.model_forward_backward_pass( + tape=tape, + input_model=self, + x_batch=x, + y_batch=y, + registry_generator_fn=registry_generator_fn, + weight_batch=weights, + num_microbatches=num_microbatches, + trainable_vars=self.trainable_variables, + ) + ) clipped_grads, y_pred, clipping_loss = ( clip_grads.compute_clipped_gradients_and_outputs( input_model=self, + registry_fn_outputs_list=registry_fn_outputs_list, + layer_grad_vars=layer_grad_vars, x_batch=x, y_batch=y, weight_batch=weights, l2_norm_clip=self._l2_norm_clip, - layer_registry=self._layer_registry, num_microbatches=self._num_microbatches, clipping_loss=self._clipping_loss, ) ) output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss - if self._noise_multiplier > 0: + + noise_multiplier, noise_multiplier_sparse = self._noise_multiplier, None + contribution_counts = None + if self._use_sparse_noise and self._sparse_selection_layer_registry: + logging.info('Using sparse noise.') + + varname_to_contribution_counts_fns = ( + sparse_noise_utils.extract_varname_to_contribution_counts_fns( + registry_fn_outputs_list, + self.trainable_variables, + ) + ) + contribution_counts = sparse_noise_utils.get_contribution_counts( + self.trainable_variables, + layer_grad_vars, + varname_to_contribution_counts_fns, + ) + + noise_multiplier_sparse, noise_multiplier = ( + sparse_noise_utils.split_noise_multiplier( + noise_multiplier, + self._sparse_selection_ratio, + contribution_counts, + ) + ) + logging.info( + 'Split noise multiplier for gradient noise: %s and partition' + ' selection: %s', + noise_multiplier, + noise_multiplier_sparse, + ) + + if noise_multiplier > 0: grads = gradient_clipping_utils.add_aggregate_noise( clipped_grads, num_microbatches, self._l2_norm_clip, - self._noise_multiplier, + noise_multiplier, loss_reduction=None, loss_model=self, + use_sparse_noise=self._use_sparse_noise, + sparse_noise_multiplier=noise_multiplier_sparse, + sparse_selection_threshold=self._sparse_selection_threshold, + sparse_contribution_counts=contribution_counts, ) else: grads = clipped_grads diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD b/tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD index 8698b511..e003baa7 100644 --- a/tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD +++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD @@ -5,6 +5,7 @@ licenses(["notice"]) py_library( name = "sparse_noise_utils", srcs = ["sparse_noise_utils.py"], + deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils"], ) py_test( diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py b/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py index 839a5591..7a70dca3 100644 --- a/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py +++ 
b/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py @@ -16,10 +16,12 @@ For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357. """ +import collections from typing import Mapping, Optional, Sequence from scipy import stats import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils import tensorflow_probability as tfp @@ -288,6 +290,43 @@ def add_sparse_gradient_noise( ) +def extract_varname_to_contribution_counts_fns( + registry_fn_outputs_list: list[ + gradient_clipping_utils.RegistryGeneratorFunctionOutput + ], + trainable_vars: list[tf.Variable], +) -> dict[str, list[tf.Tensor | None]]: + """Extracts a map of contribution count fns from generator outputs. + + TODO. Move to sparse_noise_utils.py. + + Args: + registry_fn_outputs_list: A list of `RegistryGeneratorFunctionOutput` + instances returned by + `gradient_clipping_utils.model_forward_backward_pass`. + trainable_vars: A list of trainable variables. + + Returns: + A `dict` from varname to contribution counts functions + """ + if trainable_vars is not None: + # Create a set using `ref()` for fast set membership check. tf.Variable + # itself is not hashable. + trainable_vars = set([v.ref() for v in trainable_vars]) + + varname_to_contribution_counts_fns = collections.defaultdict(list) + for registry_fn_output in registry_fn_outputs_list: + if trainable_vars is None or any( + w.ref() in trainable_vars + for w in registry_fn_output.layer_trainable_weights + ): + if registry_fn_output.varname_to_count_contribution_fn is not None: + varname_to_contribution_counts_fns.update( + registry_fn_output.varname_to_count_contribution_fn + ) + return varname_to_contribution_counts_fns + + def get_contribution_counts( trainable_vars: list[tf.Variable], grads: list[tf.Tensor], diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py b/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py index 82c7e762..5dcad197 100644 --- a/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py +++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py @@ -19,7 +19,7 @@ InputArgs = Sequence[Any] InputKwargs = Mapping[str, Any] -SparseGradient = tf.IndexedSlices +SparseGradient = tf.IndexedSlices | tf.SparseTensor ContributionCountHistogram = tf.SparseTensor ContributionCountHistogramFn = Callable[ [SparseGradient], Mapping[str, ContributionCountHistogram]
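
A minimal sketch of how the refactored pieces in this patch are intended to compose outside of `DPModel.train_step`: the registry generator and the forward/backward pass now live in `gradient_clipping_utils`, their outputs feed `clip_grads.compute_clipped_gradients_and_outputs`, and noising moves to the new `noise_utils.add_aggregate_noise`. This assumes a `tf.keras.Model` compiled with a loss (so the per-example loss can be inferred); the names `model`, `x_batch`, `y_batch`, and the clip/noise values are hypothetical placeholders, not part of the patch.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils


def dp_gradients(model, x_batch, y_batch, l2_norm_clip=1.0, noise_multiplier=1.0):
  # Record per-layer intermediates needed for fast per-example gradient norms.
  tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
  registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
      tape=tape,
      layer_registry=layer_registry.make_default_layer_registry(),
      sparse_noise_layer_registry=None,  # dense Gaussian noise only in this sketch
      num_microbatches=None,
  )
  layer_grad_vars, registry_fn_outputs_list = (
      gradient_clipping_utils.model_forward_backward_pass(
          tape=tape,
          input_model=model,
          x_batch=x_batch,
          y_batch=y_batch,
          registry_generator_fn=registry_generator_fn,
      )
  )
  # Clip each per-example gradient to `l2_norm_clip` using the recorded outputs.
  clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs(
      input_model=model,
      registry_fn_outputs_list=registry_fn_outputs_list,
      layer_grad_vars=layer_grad_vars,
      l2_norm_clip=l2_norm_clip,
      x_batch=x_batch,
      y_batch=y_batch,
  )
  # Add Gaussian noise scaled to the clipping norm; the loss reduction
  # ('mean' vs. 'sum') is inferred from the compiled loss of `model`.
  batch_size = tf.shape(y_batch)[0]
  return noise_utils.add_aggregate_noise(
      clipped_grads,
      batch_size,
      l2_norm_clip,
      noise_multiplier,
      loss_model=model,
  )

With `use_sparse_noise=True`, a sparse-noise layer registry, and contribution counts derived from `registry_fn_outputs_list`, the same `add_aggregate_noise` call would instead add sparsity-preserving noise to `tf.IndexedSlices` gradients, as exercised in `noise_utils_test.py` above.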