
Sparsity Preserving DP-SGD in TF Privacy
Add support for adding sparsity preserving noise in add_aggregate_noise

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 660521257
tensorflower-gardener committed Aug 8, 2024
1 parent 09c6875 commit a02c397
Showing 3 changed files with 172 additions and 10 deletions.
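
For context, a minimal usage sketch of the new API added in this commit (the tensors and hyperparameter values below are illustrative assumptions, not part of the commit):

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils

# A dense gradient and a sparse (e.g. embedding) gradient, already clipped.
dense_grad = tf.zeros((4, 2))
sparse_grad = tf.IndexedSlices(
    values=tf.ones((2, 3)),
    indices=tf.constant([1, 4]),
    dense_shape=tf.constant([6, 3]),
)
# Per-row contribution counts for the sparse gradient; dense entries get None.
counts = tf.SparseTensor(indices=[[1], [4]], values=[5.0, 7.0], dense_shape=[6])

config = noise_utils.SparsityPreservingNoiseConfig(
    sparse_noise_multiplier=1.0,
    sparse_selection_threshold=4,
    sparse_contribution_counts=[None, counts],
)
noised_grads = noise_utils.add_aggregate_noise(
    clipped_grads=[dense_grad, sparse_grad],
    batch_size=8,
    l2_norm_clip=1.0,
    noise_multiplier=2.0,
    loss_reduction='mean',
    sparse_noise_config=config,
)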
1 change: 1 addition & 0 deletions tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
@@ -84,6 +84,7 @@ py_library(
py_library(
name = "noise_utils",
srcs = ["noise_utils.py"],
deps = ["//tensorflow_privacy/privacy/sparsity_preserving_noise:sparse_noise_utils"],
)

py_test(
89 changes: 79 additions & 10 deletions tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py
@@ -14,10 +14,21 @@
"""Utility functions that help in adding noise to gradients."""

from collections.abc import Sequence
import dataclasses
from typing import Literal, Optional

from absl import logging
import tensorflow as tf
from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils


@dataclasses.dataclass
class SparsityPreservingNoiseConfig:
"""Configuration for adding noise to gradients."""

sparse_noise_multiplier: float = 0.0
sparse_selection_threshold: int = 0
sparse_contribution_counts: Optional[Sequence[tf.SparseTensor]] = None


def _infer_loss_reduction_type(model: tf.keras.Model):
@@ -44,21 +55,53 @@ def _infer_loss_reduction_type(model: tf.keras.Model):
)


def _add_dense_aggregate_noise(
grad: tf.Tensor,
noise_multiplier: float,
sensitivity: float,
) -> tf.Tensor:
"""Adds dense noise to a dense gradient."""
return grad + tf.random.normal(
tf.shape(grad), mean=0.0, stddev=noise_multiplier * sensitivity
)


def _add_sparse_aggregate_noise(
grad: tf.IndexedSlices,
contribution_counts: tf.SparseTensor,
noise_multiplier: float,
noise_multiplier_sparse: float,
sensitivity: float,
sparse_selection_threshold: int,
) -> tf.IndexedSlices:
"""Adds sparse noise to a sparse gradient."""
return sparse_noise_utils.add_sparse_noise(
grad=grad,
contribution_counts=contribution_counts,
noise_multiplier=noise_multiplier,
noise_multiplier_sparse=noise_multiplier_sparse,
l2_norm_clip=sensitivity,
threshold=sparse_selection_threshold,
)


def add_aggregate_noise(
clipped_grads: list[tf.Tensor],
clipped_grads: list[tf.Tensor | tf.IndexedSlices],
batch_size: tf.Tensor,
l2_norm_clip: float,
noise_multiplier: float,
loss_reduction: Optional[Literal['mean', 'sum']] = None,
loss_model: Optional[tf.keras.Model] = None,
) -> Sequence[tf.Tensor]:
sparse_noise_config: Optional[SparsityPreservingNoiseConfig] = None,
) -> Sequence[tf.Tensor | tf.IndexedSlices]:
"""Adds noise to a collection of clipped gradients.
The magnitude of the noise depends on the aggregation strategy of the
input model's loss function.
Args:
clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
clipped_grads: A list of `tf.Tensor`s or `tf.IndexedSlices`s representing
the clipped gradients.
batch_size: The batch size. Used for normalizing the noise when
`loss_reduction` is 'sum'.
l2_norm_clip: Clipping norm (max L2 norm of each gradient).
@@ -68,11 +111,14 @@ def add_aggregate_noise(
aggregation type must be inferred from `input_model.loss`.
loss_model: An optional `tf.keras.Model` used to infer the loss reduction
strategy from if `loss_reduction` is `None`.
sparse_noise_config: A `SparsityPreservingNoiseConfig` instance containing
the configuration for adding sparse noise. If None, all noise added is
dense.
Returns:
A list of tensors containing the clipped gradients, but with the right
amount of Gaussian noise added to them (depending on the reduction
strategy of the loss function).
amount of Gaussian or sparse Gaussian noise added to them (depending on
the reduction strategy of the loss function).
Raises:
ValueError: If both `loss_model` and `loss_reduction` are `None` or if
@@ -103,13 +149,36 @@ def add_aggregate_noise(
'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.'
)

if sparse_noise_config is None:
sparse_contribution_counts = tf.nest.map_structure(
lambda x: None, clipped_grads
)
else:
sparse_contribution_counts = sparse_noise_config.sparse_contribution_counts

scale = l2_norm_clip
if loss_reduction == 'mean':
scale /= tf.cast(batch_size, tf.float32)

def add_noise(g):
return g + tf.random.normal(
tf.shape(g), mean=0.0, stddev=noise_multiplier * scale
)
def add_noise(grad, contribution_counts):
if (
sparse_noise_config is not None
and isinstance(grad, tf.IndexedSlices)
and contribution_counts is not None
):
return _add_sparse_aggregate_noise(
grad=grad,
contribution_counts=contribution_counts,
noise_multiplier=noise_multiplier,
noise_multiplier_sparse=sparse_noise_config.sparse_noise_multiplier,
sensitivity=scale,
sparse_selection_threshold=sparse_noise_config.sparse_selection_threshold,
)
else:
return _add_dense_aggregate_noise(
grad=grad, noise_multiplier=noise_multiplier, sensitivity=scale
)

return tf.nest.map_structure(add_noise, clipped_grads)
return tf.nest.map_structure(
add_noise, clipped_grads, sparse_contribution_counts
)
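
To make the noise scaling concrete, a short standalone sketch of the dense path above (the values are illustrative assumptions; it mirrors what _add_dense_aggregate_noise does for 'mean' reduction):

import tensorflow as tf

l2_norm_clip = 3.0
noise_multiplier = 2.0
batch_size = 10

# With 'mean' reduction, the clipping norm (sensitivity) is divided by the
# batch size, so each coordinate receives Gaussian noise with standard
# deviation noise_multiplier * l2_norm_clip / batch_size.
scale = l2_norm_clip / batch_size
grad = tf.zeros((5,))
noised = grad + tf.random.normal(
    tf.shape(grad), mean=0.0, stddev=noise_multiplier * scale
)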
92 changes: 92 additions & 0 deletions tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils_test.py
@@ -70,3 +70,95 @@ def test_noise_is_computed_correctly(
computed_std = np.std(noised_grads[0] - clipped_grads[0])
expected_std = l2_norm_clip * noise_multiplier * scale
self.assertNear(computed_std, expected_std, 0.1 * expected_std)

@parameterized.product(
l2_norm_clip=[3.0, 5.0],
noise_multiplier=[2.0, 4.0],
sparse_noise_multiplier=[1.0],
batch_size=[1, 2, 10],
model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'],
noise_fn_reduction=[None, 'mean', 'sum'],
)
def test_sparse_noise_is_computed_correctly(
self,
l2_norm_clip,
noise_multiplier,
sparse_noise_multiplier,
batch_size,
model_fn_reduction,
noise_fn_reduction,
):
# Skip invalid combinations.
if model_fn_reduction is None and noise_fn_reduction is None:
return
if model_fn_reduction is not None and noise_fn_reduction is not None:
return
# Make a simple model container for storing the loss.
if model_fn_reduction is not None:
linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
linear_model.compile(
loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction)
)
else:
linear_model = None
# The main computation is done on a deterministic dummy vector.
num_units = 100
dense_grad = tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1)
sparse_grad = tf.IndexedSlices(
values=tf.ones((3, 4)),
indices=tf.constant([0, 3, 5]),
dense_shape=tf.constant([8, 4]),
)
sparse_grad_contribution_counts = tf.SparseTensor(
indices=[[0], [3], [5]],
values=[10.0, 10.0, 20.0],
dense_shape=[8],
)

sparse_noise_config = noise_utils.SparsityPreservingNoiseConfig(
sparse_noise_multiplier=sparse_noise_multiplier,
sparse_selection_threshold=8,
sparse_contribution_counts=[None, sparse_grad_contribution_counts],
)

sparse_noised_grad, dense_noised_grad = noise_utils.add_aggregate_noise(
clipped_grads=[dense_grad, sparse_grad],
batch_size=batch_size,
l2_norm_clip=l2_norm_clip,
noise_multiplier=noise_multiplier,
loss_model=linear_model,
sparse_noise_config=sparse_noise_config,
)
self.assertContainsSubset(
sparse_grad.indices.numpy().tolist(),
sparse_noised_grad.indices.numpy().tolist(),
)
sparse_noised_grad_dense = tf.scatter_nd(
tf.reshape(sparse_noised_grad.indices, (-1, 1)),
sparse_noised_grad.values,
shape=(8, 4),
).numpy()
sparse_noised_grad_valid_indices = sparse_noised_grad_dense[
sparse_grad.indices.numpy()
]
sparse_grad_values = sparse_grad.values.numpy()
self.assertTrue(
np.all(
np.not_equal(sparse_noised_grad_valid_indices, sparse_grad_values)
)
)
scale = (
1.0
if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum'
else 1.0 / batch_size
)
# The statistic checked below is the standard deviation of the added noise.
expected_std = l2_norm_clip * noise_multiplier * scale

sparse_computed_std = np.std(
sparse_noised_grad_valid_indices - sparse_grad_values
)
self.assertNear(sparse_computed_std, expected_std, 0.1 * expected_std)

dense_computed_std = np.std(dense_noised_grad - dense_grad)
self.assertNear(dense_computed_std, expected_std, 0.1 * expected_std)
