Sparsity Preserving DP-SGD in TF Privacy
Add support for calculating contribution counts to the registry function for sparsity-preserving noise.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.
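At a high level, the registry function reports, for each variable of a supported layer, how many examples contributed to each embedding row in the batch; the noise step later uses these contribution counts to decide where to add noise, keeping the gradient sparse. A minimal sketch of such a per-row count for a plain embedding lookup; the helper below is illustrative only, not the library API:

import tensorflow as tf

def embedding_contribution_counts(ids: tf.Tensor, vocab_size: int) -> tf.Tensor:
  # ids: [batch_size, num_ids] integer lookups into an embedding table.
  # Deduplicate within each example so one example adds at most 1 per row,
  # then sum over the batch to get per-row contribution counts.
  def per_example(row):
    unique_rows, _ = tf.unique(row)
    return tf.math.bincount(unique_rows, minlength=vocab_size, dtype=tf.float32)

  per_example_counts = tf.map_fn(
      per_example, ids,
      fn_output_signature=tf.TensorSpec([vocab_size], tf.float32))
  return tf.reduce_sum(per_example_counts, axis=0)

# Example: embedding_contribution_counts(tf.constant([[0, 1, 1], [1, 3, 3]]), 5)
# -> [1., 2., 0., 1., 0.]  (row 1 is touched by both examples)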

PiperOrigin-RevId: 660525768
tensorflower-gardener committed Aug 12, 2024
1 parent 09c6875 commit 8f8130d
Showing 8 changed files with 292 additions and 11 deletions.
9 changes: 8 additions & 1 deletion tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
@@ -46,6 +8 @@ py_library(
":common_manip_utils",
":layer_registry",
":type_aliases",
"//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
"//tensorflow_privacy/privacy/sparsity_preserving_noise:type_aliases",
],
)

@@ -55,7 +57,11 @@ py_test(
python_version = "PY3",
shard_count = 8,
srcs_version = "PY3",
deps = [":gradient_clipping_utils"],
deps = [
":gradient_clipping_utils",
":layer_registry",
"//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
],
)

py_library(
@@ -84,6 +90,7 @@ py_library(
py_library(
name = "noise_utils",
srcs = ["noise_utils.py"],
deps = ["//tensorflow_privacy/privacy/sparsity_preserving_noise:sparse_noise_utils"],
)

py_test(
@@ -164,6 +164,7 @@ def compute_gradient_norms(
registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
tape=tape,
layer_registry=layer_registry,
sparse_noise_layer_registry=None,
num_microbatches=num_microbatches,
)
layer_grad_vars, generator_outputs_list = (
@@ -132,6 +132,7 @@ def _run_model_forward_backward_pass(
registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
tape=tape,
layer_registry=layer_registry.make_default_layer_registry(),
sparse_noise_layer_registry=None,
num_microbatches=None,
)
layer_grad_vars, registry_fn_outputs_list = (
@@ -22,13 +22,18 @@
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases as sn_type_aliases


@dataclasses.dataclass(frozen=True)
class RegistryGeneratorFunctionOutput:
layer_id: str
layer_vars: Optional[Sequence[tf.Variable]]
layer_sqr_norm_fn: Optional[type_aliases.SquareNormFunction]
varname_to_count_contribution_fn: Optional[
dict[str, sn_type_aliases.ContributionCountHistogramFn]
]
layer_trainable_weights: Optional[Sequence[tf.Variable]]


@@ -46,6 +51,7 @@ def has_internal_compute_graph(input_object: Any):
def get_registry_generator_fn(
tape: tf.GradientTape,
layer_registry: lr.LayerRegistry,
sparse_noise_layer_registry: snlr.LayerRegistry,
num_microbatches: Optional[type_aliases.BatchSize] = None,
) -> Optional[Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]]:
"""Creates the generator function for `model_forward_backward_pass()`.
@@ -58,6 +64,10 @@ def get_registry_generator_fn(
`output` is the pre-activation tensor, `sqr_grad_norms` is related to the
squared norms of a layer's pre-activation tensor, and `vars` are the
relevant trainable weights.
sparse_noise_layer_registry: A `LayerRegistry` instance containing functions
that help compute contribution counts for sparse noise. See
`tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry` for
more details.
num_microbatches: An optional number or scalar `tf.Tensor` for the number of
microbatches. If not None, indicates that the loss is grouped into
num_microbatches (in this case, the batch dimension needs to be a multiple
@@ -83,6 +93,16 @@ def registry_generator_fn(layer_instance, args, kwargs):
'be used for efficient gradient clipping.'
% layer_instance.__class__.__name__
)
varname_to_count_contribution_fn = None
if sparse_noise_layer_registry and sparse_noise_layer_registry.is_elem(
layer_instance
):
count_contribution_registry_fn = sparse_noise_layer_registry.lookup(
layer_instance
)
varname_to_count_contribution_fn = count_contribution_registry_fn(
layer_instance, args, kwargs, num_microbatches
)
registry_fn = layer_registry.lookup(layer_instance)
(layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
layer_instance, args, kwargs, tape, num_microbatches
@@ -91,6 +111,7 @@ def registry_generator_fn(layer_instance, args, kwargs):
layer_id=str(id(layer_instance)),
layer_vars=layer_vars,
layer_sqr_norm_fn=layer_sqr_norm_fn,
varname_to_count_contribution_fn=varname_to_count_contribution_fn,
layer_trainable_weights=layer_instance.trainable_weights,
)
else:
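The new `sparse_noise_layer_registry` argument is the opt-in hook: when a layer type is registered there, the generator attaches a `varname_to_count_contribution_fn` dict to that layer's output, while passing `None` (as the updated call sites above do) preserves the previous dense-only behavior. A usage sketch mirroring the test below; the registry function body is a placeholder, not a real contribution-count implementation:

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr

def embedding_count_registry_fn(layer_instance, args, kwargs, num_microbatches):
  # Placeholder: a real registry function returns a dict mapping each of the
  # layer's variable names to a contribution-count histogram function.
  def count_contribution_fn(inputs):
    return None
  return {'embedding': count_contribution_fn}

sparse_registry = snlr.LayerRegistry()
sparse_registry.insert(tf.keras.layers.Embedding, embedding_count_registry_fn)

with tf.GradientTape() as tape:
  registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
      tape=tape,
      layer_registry=lr.make_default_layer_registry(),
      sparse_noise_layer_registry=sparse_registry,  # None disables sparse noise
      num_microbatches=None,
  )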
@@ -17,6 +17,8 @@
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr


# ==============================================================================
@@ -175,5 +177,92 @@ def test_new_custom_layer_spec(self):
)


class RegistryGeneratorFnTest(tf.test.TestCase, parameterized.TestCase):

def _get_sparse_layer_registry(self):
def count_contribution_fn(_):
return None

def registry_fn(*_):
return {'var': count_contribution_fn}

registry = snlr.LayerRegistry()
registry.insert(tf.keras.layers.Embedding, registry_fn)
return registry, count_contribution_fn

def _get_layer_registry(self):
var = tf.Variable(1.0)
output = tf.ones((1, 1))

def sqr_norm_fn(_):
return None

def registry_fn(*_):
return [var], output, sqr_norm_fn

registry = lr.LayerRegistry()
registry.insert(tf.keras.layers.Embedding, registry_fn)
registry.insert(tf.keras.layers.Dense, registry_fn)
return registry, var, output, sqr_norm_fn

def test_registry_generator_fn(self):
inputs = tf.constant([[0, 1]])
model = tf.keras.Sequential([
tf.keras.layers.Embedding(10, 1),
tf.keras.layers.Dense(1),
])

sparse_layer_registry, count_contribution_fn = (
self._get_sparse_layer_registry()
)
layer_registry, var, output, sqr_norm_fn = self._get_layer_registry()
registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
tape=tf.GradientTape(),
layer_registry=layer_registry,
sparse_noise_layer_registry=sparse_layer_registry,
num_microbatches=None,
)
embedding_layer = model.layers[0]
out, embedding_registry_generator_fn_output = registry_generator_fn(
embedding_layer,
[inputs],
{},
)
expected_embedding_registry_generator_fn_output = (
gradient_clipping_utils.RegistryGeneratorFunctionOutput(
layer_id=str(id(embedding_layer)),
layer_vars=[var],
layer_sqr_norm_fn=sqr_norm_fn,
varname_to_count_contribution_fn={'var': count_contribution_fn},
layer_trainable_weights=embedding_layer.trainable_weights,
)
)
self.assertEqual(
embedding_registry_generator_fn_output,
expected_embedding_registry_generator_fn_output,
)
self.assertEqual(out, output)
dense_layer = model.layers[1]
out, dense_registry_generator_fn_output = registry_generator_fn(
dense_layer,
[inputs],
{},
)
expected_dense_registry_generator_fn_output = (
gradient_clipping_utils.RegistryGeneratorFunctionOutput(
layer_id=str(id(dense_layer)),
layer_vars=[var],
layer_sqr_norm_fn=sqr_norm_fn,
varname_to_count_contribution_fn=None,
layer_trainable_weights=dense_layer.trainable_weights,
)
)
self.assertEqual(
dense_registry_generator_fn_output,
expected_dense_registry_generator_fn_output,
)
self.assertEqual(out, output)


if __name__ == '__main__':
tf.test.main()
89 changes: 79 additions & 10 deletions tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py
@@ -14,10 +14,21 @@
"""Utility functions that help in adding noise to gradients."""

from collections.abc import Sequence
import dataclasses
from typing import Literal, Optional

from absl import logging
import tensorflow as tf
from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils


@dataclasses.dataclass
class SparsityPreservingNoiseConfig:
"""Configuration for adding noise to gradients."""

sparse_noise_multiplier: float = 0.0
sparse_selection_threshold: int = 0
sparse_contribution_counts: Optional[Sequence[tf.SparseTensor]] = None


def _infer_loss_reduction_type(model: tf.keras.Model):
@@ -44,21 +55,53 @@ def _infer_loss_reduction_type(model: tf.keras.Model):
)


def _add_dense_aggregate_noise(
grad: tf.Tensor,
noise_multiplier: float,
sensitivity: float,
) -> tf.Tensor:
"""Adds dense noise to a dense gradient."""
return grad + tf.random.normal(
tf.shape(grad), mean=0.0, stddev=noise_multiplier * sensitivity
)


def _add_sparse_aggregate_noise(
grad: tf.IndexedSlices,
contribution_counts: tf.SparseTensor,
noise_multiplier: float,
noise_multiplier_sparse: float,
sensitivity: float,
sparse_selection_threshold: int,
) -> tf.IndexedSlices:
"""Adds sparse noise to a sparse gradient."""
return sparse_noise_utils.add_sparse_noise(
grad=grad,
contribution_counts=contribution_counts,
noise_multiplier=noise_multiplier,
noise_multiplier_sparse=noise_multiplier_sparse,
l2_norm_clip=sensitivity,
threshold=sparse_selection_threshold,
)


def add_aggregate_noise(
clipped_grads: list[tf.Tensor],
clipped_grads: list[tf.Tensor | tf.IndexedSlices],
batch_size: tf.Tensor,
l2_norm_clip: float,
noise_multiplier: float,
loss_reduction: Optional[Literal['mean', 'sum']] = None,
loss_model: Optional[tf.keras.Model] = None,
) -> Sequence[tf.Tensor]:
sparse_noise_config: Optional[SparsityPreservingNoiseConfig] = None,
) -> Sequence[tf.Tensor | tf.IndexedSlices]:
"""Adds noise to a collection of clipped gradients.
The magnitude of the noise depends on the aggregation strategy of the
input model's loss function.
Args:
clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
clipped_grads: A list of `tf.Tensor`s or `tf.IndexedSlices`s representing
the clipped gradients.
batch_size: The batch size. Used for normalizing the noise when
`loss_reduction` is 'sum'.
l2_norm_clip: Clipping norm (max L2 norm of each gradient).
@@ -68,11 +111,14 @@ def add_aggregate_noise(
aggregation type must be inferred from `input_model.loss`.
loss_model: An optional `tf.keras.Model` used to infer the loss reduction
strategy from if `loss_reduction` is `None`.
sparse_noise_config: A `SparsityPreservingNoiseConfig` instance containing
the configuration for adding sparse noise. If None, all noise added is
dense.
Returns:
A list of tensors containing the clipped gradients, but with the right
amount of Gaussian noise added to them (depending on the reduction
strategy of the loss function).
amount of Gaussian or sparse Gaussian noise added to them (depending on
the reduction strategy of the loss function).
Raises:
ValueError: If both `loss_model` and `loss_reduction` are `None` or if
they are both not `None`.
@@ -103,13 +149,36 @@
'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.'
)

if sparse_noise_config is None:
sparse_contribution_counts = tf.nest.map_structure(
lambda x: None, clipped_grads
)
else:
sparse_contribution_counts = sparse_noise_config.sparse_contribution_counts

scale = l2_norm_clip
if loss_reduction == 'mean':
scale /= tf.cast(batch_size, tf.float32)

def add_noise(g):
return g + tf.random.normal(
tf.shape(g), mean=0.0, stddev=noise_multiplier * scale
)
def add_noise(grad, contribution_counts):
if (
sparse_noise_config is not None
and isinstance(grad, tf.IndexedSlices)
and contribution_counts is not None
):
return _add_sparse_aggregate_noise(
grad=grad,
contribution_counts=contribution_counts,
noise_multiplier=noise_multiplier,
noise_multiplier_sparse=sparse_noise_config.sparse_noise_multiplier,
sensitivity=scale,
sparse_selection_threshold=sparse_noise_config.sparse_selection_threshold,
)
else:
return _add_dense_aggregate_noise(
grad=grad, noise_multiplier=noise_multiplier, sensitivity=scale
)

return tf.nest.map_structure(add_noise, clipped_grads)
return tf.nest.map_structure(
add_noise, clipped_grads, sparse_contribution_counts
)
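End to end, contribution counts gathered during the forward pass are threaded into `add_aggregate_noise` through the new config. A sketch under stated assumptions: the gradient values, hyperparameters, and the `[vocab_size, 1]` layout of the count histogram are illustrative, and the exact `tf.SparseTensor` shape expected by `sparse_noise_utils.add_sparse_noise` may differ:

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils

# One dense gradient (e.g. a Dense kernel) and one sparse embedding gradient.
dense_grad = tf.ones((4, 2))
sparse_grad = tf.IndexedSlices(
    values=tf.ones((2, 3)),
    indices=tf.constant([0, 5], dtype=tf.int64),
    dense_shape=tf.constant([10, 3], dtype=tf.int64),
)

# Contribution counts aligned one-to-one with clipped_grads: None for dense
# entries, a sparse per-row histogram for sparse ones (shape assumed here).
counts = tf.SparseTensor(
    indices=[[0, 0], [5, 0]], values=[3.0, 1.0], dense_shape=[10, 1])

noised_grads = noise_utils.add_aggregate_noise(
    clipped_grads=[dense_grad, sparse_grad],
    batch_size=tf.constant(4),
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    loss_reduction='mean',
    sparse_noise_config=noise_utils.SparsityPreservingNoiseConfig(
        sparse_noise_multiplier=1.1,
        sparse_selection_threshold=2,
        sparse_contribution_counts=[None, counts],
    ),
)

With `sparse_noise_config=None`, every gradient falls through to `_add_dense_aggregate_noise`, so existing callers see unchanged behavior.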