From ec59876847dd6cec458689f87870a85c36fec592 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 7 Aug 2024 13:52:14 -0700 Subject: [PATCH] Sparsity Preserving DP-SGD in TF Privacy Add support for calculating contribution counts to registry function for sparsity preserving noise. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm. PiperOrigin-RevId: 660525768 --- .../privacy/fast_gradient_clipping/BUILD | 9 +- .../fast_gradient_clipping/clip_grads.py | 1 + .../fast_gradient_clipping/clip_grads_test.py | 1 + .../gradient_clipping_utils.py | 21 +++++ .../gradient_clipping_utils_test.py | 89 ++++++++++++++++++ .../fast_gradient_clipping/noise_utils.py | 89 ++++++++++++++++-- .../noise_utils_test.py | 92 +++++++++++++++++++ .../privacy/keras_models/dp_keras_model.py | 1 + 8 files changed, 292 insertions(+), 11 deletions(-) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD index f5b920f8..088294eb 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD @@ -46,6 +46,8 @@ py_library( ":common_manip_utils", ":layer_registry", ":type_aliases", + "//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry", + "//tensorflow_privacy/privacy/sparsity_preserving_noise:type_aliases", ], ) @@ -55,7 +57,11 @@ py_test( python_version = "PY3", shard_count = 8, srcs_version = "PY3", - deps = [":gradient_clipping_utils"], + deps = [ + ":gradient_clipping_utils", + ":layer_registry", + "//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry", + ], ) py_library( @@ -84,6 +90,7 @@ py_library( py_library( name = "noise_utils", srcs = ["noise_utils.py"], + deps = ["//tensorflow_privacy/privacy/sparsity_preserving_noise:sparse_noise_utils"], ) py_test( diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py index e31f1781..be7d92f9 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -164,6 +164,7 @@ def compute_gradient_norms( registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn( tape=tape, layer_registry=layer_registry, + sparse_noise_layer_registry=None, num_microbatches=num_microbatches, ) layer_grad_vars, generator_outputs_list = ( diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py index 7b91461c..6ec47941 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py @@ -132,6 +132,7 @@ def _run_model_forward_backward_pass( registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn( tape=tape, layer_registry=layer_registry.make_default_layer_registry(), + sparse_noise_layer_registry=None, num_microbatches=None, ) layer_grad_vars, registry_fn_outputs_list = ( diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py index bac323ce..989a69c2 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py @@ -22,6 +22,8 @@ from 
tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
+from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
+from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases as sn_type_aliases


@dataclasses.dataclass(frozen=True)
@@ -29,6 +31,9 @@ class RegistryGeneratorFunctionOutput:
  layer_id: str
  layer_vars: Optional[Sequence[tf.Variable]]
  layer_sqr_norm_fn: Optional[type_aliases.SquareNormFunction]
+  varname_to_count_contribution_fn: Optional[
+      dict[str, sn_type_aliases.ContributionCountHistogramFn]
+  ]
  layer_trainable_weights: Optional[Sequence[tf.Variable]]


@@ -46,6 +51,7 @@ def has_internal_compute_graph(input_object: Any):
def get_registry_generator_fn(
    tape: tf.GradientTape,
    layer_registry: lr.LayerRegistry,
+    sparse_noise_layer_registry: Optional[snlr.LayerRegistry],
    num_microbatches: Optional[type_aliases.BatchSize] = None,
) -> Optional[Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]]:
  """Creates the generator function for `model_forward_backward_pass()`.
@@ -58,6 +64,10 @@ def get_registry_generator_fn(
      `output` is the pre-activator tensor, `sqr_grad_norms` is related to the
      squared norms of a layer's pre-activation tensor, and `vars` are relevant
      trainable weights.
+    sparse_noise_layer_registry: A `LayerRegistry` instance containing
+      functions that help compute contribution counts for sparse noise, or
+      `None` to skip computing contribution counts. See
+      `tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry`.
    num_microbatches: An optional number or scalar `tf.Tensor` for the number
      of microbatches. If not None, indicates that the loss is grouped into
      num_microbatches (in this case, the batch dimension needs to be a
      multiple of num_microbatches).
@@ -83,6 +93,16 @@ def registry_generator_fn(layer_instance, args, kwargs):
            'be used for efficient gradient clipping.'
            % layer_instance.__class__.__name__
        )
+      varname_to_count_contribution_fn = None
+      if sparse_noise_layer_registry and sparse_noise_layer_registry.is_elem(
+          layer_instance
+      ):
+        count_contribution_registry_fn = sparse_noise_layer_registry.lookup(
+            layer_instance
+        )
+        varname_to_count_contribution_fn = count_contribution_registry_fn(
+            layer_instance, args, kwargs, num_microbatches
+        )
      registry_fn = layer_registry.lookup(layer_instance)
      (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
          layer_instance, args, kwargs, tape, num_microbatches
@@ -91,6 +111,7 @@ def registry_generator_fn(layer_instance, args, kwargs):
          layer_id=str(id(layer_instance)),
          layer_vars=layer_vars,
          layer_sqr_norm_fn=layer_sqr_norm_fn,
+          varname_to_count_contribution_fn=varname_to_count_contribution_fn,
          layer_trainable_weights=layer_instance.trainable_weights,
      )
    else:
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py
index 7069273d..0535e011 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py
@@ -17,6 +17,8 @@
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
+from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr


# ==============================================================================
@@ -175,5 +177,92 @@ def test_new_custom_layer_spec(self):
    )


+class RegistryGeneratorFnTest(tf.test.TestCase, parameterized.TestCase):
+
+  def _get_sparse_layer_registry(self):
+    def count_contribution_fn(_):
+      return None
+
+    def registry_fn(*_):
+      return {'var': count_contribution_fn}
+
+    registry = snlr.LayerRegistry()
+    registry.insert(tf.keras.layers.Embedding, registry_fn)
+    return registry, count_contribution_fn
+
+  def _get_layer_registry(self):
+    var = tf.Variable(1.0)
+    output = tf.ones((1, 1))
+
+    def sqr_norm_fn(_):
+      return None
+
+    def registry_fn(*_):
+      return [var], output, sqr_norm_fn
+
+    registry = lr.LayerRegistry()
+    registry.insert(tf.keras.layers.Embedding, registry_fn)
+    registry.insert(tf.keras.layers.Dense, registry_fn)
+    return registry, var, output, sqr_norm_fn
+
+  def test_registry_generator_fn(self):
+    inputs = tf.constant([[0, 1]])
+    model = tf.keras.Sequential([
+        tf.keras.layers.Embedding(10, 1),
+        tf.keras.layers.Dense(1),
+    ])
+
+    sparse_layer_registry, count_contribution_fn = (
+        self._get_sparse_layer_registry()
+    )
+    layer_registry, var, output, sqr_norm_fn = self._get_layer_registry()
+    registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
+        tape=tf.GradientTape(),
+        layer_registry=layer_registry,
+        sparse_noise_layer_registry=sparse_layer_registry,
+        num_microbatches=None,
+    )
+    embedding_layer = model.layers[0]
+    out, embedding_registry_generator_fn_output = registry_generator_fn(
+        embedding_layer,
+        [inputs],
+        {},
+    )
+    expected_embedding_registry_generator_fn_output = (
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
+            layer_id=str(id(embedding_layer)),
+            layer_vars=[var],
+            layer_sqr_norm_fn=sqr_norm_fn,
+            varname_to_count_contribution_fn={'var': count_contribution_fn},
+            layer_trainable_weights=embedding_layer.trainable_weights,
+        )
+    )
+    self.assertEqual(
+        embedding_registry_generator_fn_output,
+        expected_embedding_registry_generator_fn_output,
+    )
+    self.assertEqual(out, output)
+    dense_layer = model.layers[1]
+    out, dense_registry_generator_fn_output = registry_generator_fn(
+        dense_layer,
+        [inputs],
+        {},
+    )
+    expected_dense_registry_generator_fn_output = (
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
+            layer_id=str(id(dense_layer)),
+            layer_vars=[var],
+            layer_sqr_norm_fn=sqr_norm_fn,
+            varname_to_count_contribution_fn=None,
+            layer_trainable_weights=dense_layer.trainable_weights,
+        )
+    )
+    self.assertEqual(
+        dense_registry_generator_fn_output,
+        expected_dense_registry_generator_fn_output,
+    )
+    self.assertEqual(out, output)
+
+
if __name__ == '__main__':
  tf.test.main()
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py
index 7dd2f157..f1ace9ed 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py
@@ -14,10 +14,21 @@
"""Utility functions that help in adding noise to gradients."""

from collections.abc import Sequence
+import dataclasses
from typing import Literal, Optional

from absl import logging
import tensorflow as tf
+from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
+
+
+@dataclasses.dataclass
+class SparsityPreservingNoiseConfig:
+  """Configuration for adding sparsity-preserving noise to gradients."""
+
+  sparse_noise_multiplier: float = 0.0
+  sparse_selection_threshold: int = 0
+  sparse_contribution_counts: Optional[Sequence[tf.SparseTensor]] = None


def _infer_loss_reduction_type(model: tf.keras.Model):
@@ -44,21 +55,53 @@ def _infer_loss_reduction_type(model: tf.keras.Model):
  )


+def _add_dense_aggregate_noise(
+    grad: tf.Tensor,
+    noise_multiplier: float,
+    sensitivity: float,
+) -> tf.Tensor:
+  """Adds dense Gaussian noise to a dense gradient."""
+  return grad + tf.random.normal(
+      tf.shape(grad), mean=0.0, stddev=noise_multiplier * sensitivity
+  )
+
+
+def _add_sparse_aggregate_noise(
+    grad: tf.IndexedSlices,
+    contribution_counts: tf.SparseTensor,
+    noise_multiplier: float,
+    noise_multiplier_sparse: float,
+    sensitivity: float,
+    sparse_selection_threshold: int,
+) -> tf.IndexedSlices:
+  """Adds sparsity-preserving noise to a sparse gradient."""
+  return sparse_noise_utils.add_sparse_noise(
+      grad=grad,
+      contribution_counts=contribution_counts,
+      noise_multiplier=noise_multiplier,
+      noise_multiplier_sparse=noise_multiplier_sparse,
+      l2_norm_clip=sensitivity,
+      threshold=sparse_selection_threshold,
+  )
+
+
def add_aggregate_noise(
-    clipped_grads: list[tf.Tensor],
+    clipped_grads: list[tf.Tensor | tf.IndexedSlices],
    batch_size: tf.Tensor,
    l2_norm_clip: float,
    noise_multiplier: float,
    loss_reduction: Optional[Literal['mean', 'sum']] = None,
    loss_model: Optional[tf.keras.Model] = None,
-) -> Sequence[tf.Tensor]:
+    sparse_noise_config: Optional[SparsityPreservingNoiseConfig] = None,
+) -> Sequence[tf.Tensor | tf.IndexedSlices]:
  """Adds noise to a collection of clipped gradients.

  The magnitude of the noise depends on the aggregation strategy of the
  input model's loss function.

  Args:
-    clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
+    clipped_grads: A list of `tf.Tensor`s or `tf.IndexedSlices` representing
+      the clipped gradients.
    batch_size: The batch size. Used for normalizing the noise when
      `loss_reduction` is 'sum'.
    l2_norm_clip: Clipping norm (max L2 norm of each gradient).
@@ -68,11 +111,14 @@ def add_aggregate_noise(
      aggregation type must be inferred from `input_model.loss`.
    loss_model: An optional `tf.keras.Model` used to infer the loss reduction
      strategy from if `loss_reduction` is `None`.
+    sparse_noise_config: A `SparsityPreservingNoiseConfig` instance containing
+      the configuration for adding sparse noise. If None, all noise added is
+      dense.

  Returns:
    A list of tensors containing the clipped gradients, but with the right
-    amount of Gaussian noise added to them (depending on the reduction
-    strategy of the loss function).
+    amount of Gaussian or sparse Gaussian noise added to them (depending on
+    the reduction strategy of the loss function).

  Raises:
    ValueError: If both `loss_model` and `loss_reduction` are `None` or if
@@ -103,13 +149,36 @@ def add_aggregate_noise(
        'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.'
    )

+  if sparse_noise_config is None:
+    sparse_contribution_counts = tf.nest.map_structure(
+        lambda x: None, clipped_grads
+    )
+  else:
+    sparse_contribution_counts = sparse_noise_config.sparse_contribution_counts
+
  scale = l2_norm_clip
  if loss_reduction == 'mean':
    scale /= tf.cast(batch_size, tf.float32)

-  def add_noise(g):
-    return g + tf.random.normal(
-        tf.shape(g), mean=0.0, stddev=noise_multiplier * scale
-    )
+  def add_noise(grad, contribution_counts):
+    if (
+        sparse_noise_config is not None
+        and isinstance(grad, tf.IndexedSlices)
+        and contribution_counts is not None
+    ):
+      return _add_sparse_aggregate_noise(
+          grad=grad,
+          contribution_counts=contribution_counts,
+          noise_multiplier=noise_multiplier,
+          noise_multiplier_sparse=sparse_noise_config.sparse_noise_multiplier,
+          sensitivity=scale,
+          sparse_selection_threshold=sparse_noise_config.sparse_selection_threshold,
+      )
+    else:
+      return _add_dense_aggregate_noise(
+          grad=grad, noise_multiplier=noise_multiplier, sensitivity=scale
+      )

-  return tf.nest.map_structure(add_noise, clipped_grads)
+  return tf.nest.map_structure(
+      add_noise, clipped_grads, sparse_contribution_counts
+  )
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils_test.py
index 9d5cac3b..880b8d3a 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils_test.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils_test.py
@@ -70,3 +70,95 @@ def test_noise_is_computed_correctly(
    computed_std = np.std(noised_grads[0] - clipped_grads[0])
    expected_std = l2_norm_clip * noise_multiplier * scale
    self.assertNear(computed_std, expected_std, 0.1 * expected_std)
+
+  @parameterized.product(
+      l2_norm_clip=[3.0, 5.0],
+      noise_multiplier=[2.0, 4.0],
+      sparse_noise_multiplier=[1.0],
+      batch_size=[1, 2, 10],
+      model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'],
+      noise_fn_reduction=[None, 'mean', 'sum'],
+  )
+  def test_sparse_noise_is_computed_correctly(
+      self,
+      l2_norm_clip,
+      noise_multiplier,
+      sparse_noise_multiplier,
+      batch_size,
+      model_fn_reduction,
+      noise_fn_reduction,
+  ):
+    # Skip invalid combinations.
+    if model_fn_reduction is None and noise_fn_reduction is None:
+      return
+    if model_fn_reduction is not None and noise_fn_reduction is not None:
+      return
+    # Make a simple model container for storing the loss.
+    if model_fn_reduction is not None:
+      linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
+      linear_model.compile(
+          loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction)
+      )
+    else:
+      linear_model = None
+    # The main computation is done on a deterministic dummy vector.
+    num_units = 100
+    dense_grad = tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1)
+    sparse_grad = tf.IndexedSlices(
+        values=tf.ones((3, 4)),
+        indices=tf.constant([0, 3, 5]),
+        dense_shape=tf.constant([8, 4]),
+    )
+    sparse_grad_contribution_counts = tf.SparseTensor(
+        indices=[[0], [3], [5]],
+        values=[10.0, 10.0, 20.0],
+        dense_shape=[8],
+    )
+
+    sparse_noise_config = noise_utils.SparsityPreservingNoiseConfig(
+        sparse_noise_multiplier=sparse_noise_multiplier,
+        sparse_selection_threshold=8,
+        sparse_contribution_counts=[None, sparse_grad_contribution_counts],
+    )
+
+    dense_noised_grad, sparse_noised_grad = noise_utils.add_aggregate_noise(
+        clipped_grads=[dense_grad, sparse_grad],
+        batch_size=batch_size,
+        l2_norm_clip=l2_norm_clip,
+        noise_multiplier=noise_multiplier,
+        loss_reduction=noise_fn_reduction,
+        loss_model=linear_model,
+        sparse_noise_config=sparse_noise_config,
+    )
+    self.assertContainsSubset(
+        sparse_grad.indices.numpy().tolist(),
+        sparse_noised_grad.indices.numpy().tolist(),
+    )
+    sparse_noised_grad_dense = tf.scatter_nd(
+        tf.reshape(sparse_noised_grad.indices, (-1, 1)),
+        sparse_noised_grad.values,
+        shape=(8, 4),
+    ).numpy()
+    sparse_noised_grad_valid_indices = sparse_noised_grad_dense[
+        sparse_grad.indices.numpy()
+    ]
+    sparse_grad_values = sparse_grad.values.numpy()
+    self.assertTrue(
+        np.all(
+            np.not_equal(sparse_noised_grad_valid_indices, sparse_grad_values)
+        )
+    )
+    scale = (
+        1.0
+        if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum'
+        else 1.0 / batch_size
+    )
+    # Across parameter settings, only the standard deviation of the added
+    # noise should vary.
+    expected_std = l2_norm_clip * noise_multiplier * scale
+
+    sparse_computed_std = np.std(
+        sparse_noised_grad_valid_indices - sparse_grad_values
+    )
+    self.assertNear(sparse_computed_std, expected_std, 0.1 * expected_std)
+
+    dense_computed_std = np.std(dense_noised_grad - dense_grad)
+    self.assertNear(dense_computed_std, expected_std, 0.1 * expected_std)
diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
index b7104f4d..472f175f 100644
--- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
+++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
@@ -280,6 +280,7 @@ def train_step(self, data):
        gradient_clipping_utils.get_registry_generator_fn(
            tape=tape,
            layer_registry=self._layer_registry,
+            sparse_noise_layer_registry=None,
            num_microbatches=num_microbatches,
        )
    )
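
For reviewers, a sketch of how the new `sparse_noise_layer_registry` argument is wired up, patterned on the unit tests in this patch. The registry function below is a stand-in: its body, the `count_contribution_fn` stub, and the variable-name key are illustrative assumptions, not part of this change.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr


def embedding_count_registry_fn(layer_instance, args, kwargs, num_microbatches):
  """Stand-in registry function mapping variable names to count functions."""
  del args, kwargs, num_microbatches  # Unused in this sketch.

  def count_contribution_fn(_):
    # A real registry function returns callables that compute per-index
    # contribution-count histograms from the layer's inputs.
    return None

  return {layer_instance.trainable_weights[0].name: count_contribution_fn}


sparse_registry = snlr.LayerRegistry()
sparse_registry.insert(tf.keras.layers.Embedding, embedding_count_registry_fn)

with tf.GradientTape() as tape:
  # Layers found in `sparse_registry` get their contribution counts recorded
  # in `RegistryGeneratorFunctionOutput`; passing None (as the existing call
  # sites in this patch do) disables the feature and preserves behavior.
  registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
      tape=tape,
      layer_registry=layer_registry.make_default_layer_registry(),
      sparse_noise_layer_registry=sparse_registry,
      num_microbatches=None,
  )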
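
And a minimal sketch of the new sparse-noise path through `noise_utils.add_aggregate_noise`, assuming this patch is applied. All shapes, contribution counts, and multiplier values are illustrative; in training, the contribution counts come from the registry functions rather than being built by hand.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils

# One dense gradient and one sparse (embedding-style) gradient, both assumed
# to have already been clipped to `l2_norm_clip`.
dense_grad = tf.ones((4, 2))
sparse_grad = tf.IndexedSlices(
    values=tf.ones((2, 3)),
    indices=tf.constant([1, 4]),
    dense_shape=tf.constant([6, 3]),
)

# Illustrative contribution counts for the sparse gradient's rows.
contribution_counts = tf.SparseTensor(
    indices=[[1], [4]], values=[12.0, 5.0], dense_shape=[6]
)

config = noise_utils.SparsityPreservingNoiseConfig(
    sparse_noise_multiplier=1.0,
    sparse_selection_threshold=4,
    # One entry per clipped gradient; None selects the dense-noise path.
    sparse_contribution_counts=[None, contribution_counts],
)

noised = noise_utils.add_aggregate_noise(
    clipped_grads=[dense_grad, sparse_grad],
    batch_size=tf.constant(32),
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    loss_reduction='mean',
    sparse_noise_config=config,
)
# noised[0] is a dense tf.Tensor with Gaussian noise added; noised[1] is
# still a tf.IndexedSlices, noised via sparse_noise_utils.add_sparse_noise.

Because gradients without contribution counts (or that are not `tf.IndexedSlices`) fall back to the dense Gaussian path, the config can be threaded through unconditionally.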