Skip to content

Commit

Permalink
Sparsity Preserving DP-SGD in TF Privacy
Browse files Browse the repository at this point in the history
Refactor utilities for adding noise into separate utility file.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 660475009
  • Loading branch information
tensorflower-gardener committed Aug 7, 2024
1 parent fc6f1dc commit 614c93f
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 158 deletions.
16 changes: 12 additions & 4 deletions tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,7 @@ py_test(
python_version = "PY3",
shard_count = 8,
srcs_version = "PY3",
deps = [
":gradient_clipping_utils",
":type_aliases",
],
deps = [":gradient_clipping_utils"],
)

py_library(
Expand All @@ -83,6 +80,11 @@ py_library(
],
)

py_library(
name = "noise_utils",
srcs = ["noise_utils.py"],
)

py_test(
name = "clip_grads_test",
srcs = ["clip_grads_test.py"],
Expand All @@ -96,3 +98,9 @@ py_test(
":type_aliases",
],
)

py_test(
name = "noise_utils_test",
srcs = ["noise_utils_test.py"],
deps = [":noise_utils"],
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
"""Utility functions that help in the computation of per-example gradient norms."""

from collections.abc import Sequence, Set
from typing import Any, Literal, Optional
from typing import Any, Optional

from absl import logging
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
Expand Down Expand Up @@ -144,101 +143,6 @@ def all_trainable_layers_are_registered(
return True


def _infer_loss_reduction_type(model: tf.keras.Model):
"""Infers what type of loss reduction is being performed."""
model_loss = model.loss
if isinstance(model_loss, tf.keras.losses.Loss):
return model_loss.reduction
elif isinstance(model.loss, dict):
reductions = set()
compiled_loss = model.compiled_loss
if compiled_loss is None:
raise ValueError('Model must be compiled for adding noise')
new_config_list = compiled_loss.get_config()['losses']
for loss_config in new_config_list:
reductions.add(loss_config['config']['reduction'])
if len(reductions) > 1:
raise ValueError(
'Reductions in models with multiple losses must all be the same'
)
return reductions.pop()
else:
raise ValueError(
'Unsupported type for adding noise: {}'.format(type(model_loss))
)


def add_aggregate_noise(
clipped_grads: list[tf.Tensor],
batch_size: tf.Tensor,
l2_norm_clip: float,
noise_multiplier: float,
loss_reduction: Optional[Literal['mean', 'sum']] = None,
loss_model: Optional[tf.keras.Model] = None,
) -> Sequence[tf.Tensor]:
"""Adds noise to a collection of clipped gradients.
The magnitude of the noise depends on the aggregation strategy of the
input model's loss function.
Args:
clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
batch_size: The batch size. Used for normalizing the noise when
`loss_reduction` is 'sum'.
l2_norm_clip: Clipping norm (max L2 norm of each gradient).
noise_multiplier: Ratio of the standard deviation to the clipping norm.
loss_reduction: An string description of how the loss is reduced over
examples. Currently supports 'mean' and 'sum'. If `None`, then the
aggregation type must be inferred from `input_model.loss`.
loss_model: An optional `tf.keras.Model` used to infer the loss reduction
strategy from if `loss_reduction` is `None`.
Returns:
A list of tensors containing the clipped gradients, but with the right
amount of Gaussian noise added to them (depending on the reduction
strategy of the loss function).
Raises:
ValueError: If both `loss_model` and `loss_reduction` are `None` or if
they are both not `None`.
"""
if loss_reduction is None and loss_model is None:
raise ValueError(
'Exactly one of `loss_reduction` and `loss_model` must be populated.'
' Instead, both arguments were `None`.'
)
if loss_reduction is not None and loss_model is not None:
raise ValueError(
'Exactly one of `loss_reduction` and `loss_model` must be populated.'
' Instead, both arguments were not `None`.'
)

if loss_reduction is None and loss_model is not None:
implicit_mean_reductions = [
tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
tf.keras.losses.Reduction.AUTO,
]
model_reduction = _infer_loss_reduction_type(loss_model)
loss_reduction = (
'mean' if model_reduction in implicit_mean_reductions else 'sum'
)
if model_reduction == tf.keras.losses.Reduction.AUTO:
logging.info(
'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.'
)

scale = l2_norm_clip
if loss_reduction == 'mean':
scale /= tf.cast(batch_size, tf.float32)

def add_noise(g):
return g + tf.random.normal(
tf.shape(g), mean=0.0, stddev=noise_multiplier * scale
)

return tf.nest.map_structure(add_noise, clipped_grads)


def generate_model_outputs_using_core_keras_layers(
input_model: tf.keras.Model,
custom_layer_set: Optional[Set[type]] = None, # pylint: disable=g-bare-generic
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Any

from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils

Expand Down Expand Up @@ -135,60 +134,6 @@ def test_outputs_are_consistent(
self.assertAllClose(computed_outputs, true_outputs)


class AddAggregateNoise(tf.test.TestCase, parameterized.TestCase):

@parameterized.product(
l2_norm_clip=[3.0, 5.0],
noise_multiplier=[2.0, 4.0],
batch_size=[1, 2, 10],
model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'],
noise_fn_reduction=[None, 'mean', 'sum'],
)
def test_noise_is_computed_correctly(
self,
l2_norm_clip,
noise_multiplier,
batch_size,
model_fn_reduction,
noise_fn_reduction,
):
# Skip invalid combinations.
if model_fn_reduction is None and noise_fn_reduction is None:
return
if model_fn_reduction is not None and noise_fn_reduction is not None:
return
# Make an simple model container for storing the loss.
if model_fn_reduction is not None:
linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
linear_model.compile(
loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction)
)
else:
linear_model = None
# The main computation is done on a deterministic dummy vector.
num_units = 100
clipped_grads = [
tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1)
]
noised_grads = gradient_clipping_utils.add_aggregate_noise(
clipped_grads,
batch_size,
l2_norm_clip,
noise_multiplier,
noise_fn_reduction,
linear_model,
)
# The only measure that varies is the standard deviation of the variation.
scale = (
1.0
if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum'
else 1.0 / batch_size
)
computed_std = np.std(noised_grads[0] - clipped_grads[0])
expected_std = l2_norm_clip * noise_multiplier * scale
self.assertNear(computed_std, expected_std, 0.1 * expected_std)


class GenerateOutputsUsingCoreKerasLayers(
tf.test.TestCase, parameterized.TestCase
):
Expand Down
115 changes: 115 additions & 0 deletions tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright 2024, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions that help in adding noise to gradients."""

from collections.abc import Sequence
from typing import Literal, Optional

from absl import logging
import tensorflow as tf


def _infer_loss_reduction_type(model: tf.keras.Model):
"""Infers what type of loss reduction is being performed."""
model_loss = model.loss
if isinstance(model_loss, tf.keras.losses.Loss):
return model_loss.reduction
elif isinstance(model.loss, dict):
reductions = set()
compiled_loss = model.compiled_loss
if compiled_loss is None:
raise ValueError('Model must be compiled for adding noise')
new_config_list = compiled_loss.get_config()['losses']
for loss_config in new_config_list:
reductions.add(loss_config['config']['reduction'])
if len(reductions) > 1:
raise ValueError(
'Reductions in models with multiple losses must all be the same'
)
return reductions.pop()
else:
raise ValueError(
'Unsupported type for adding noise: {}'.format(type(model_loss))
)


def add_aggregate_noise(
clipped_grads: list[tf.Tensor],
batch_size: tf.Tensor,
l2_norm_clip: float,
noise_multiplier: float,
loss_reduction: Optional[Literal['mean', 'sum']] = None,
loss_model: Optional[tf.keras.Model] = None,
) -> Sequence[tf.Tensor]:
"""Adds noise to a collection of clipped gradients.
The magnitude of the noise depends on the aggregation strategy of the
input model's loss function.
Args:
clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
batch_size: The batch size. Used for normalizing the noise when
`loss_reduction` is 'sum'.
l2_norm_clip: Clipping norm (max L2 norm of each gradient).
noise_multiplier: Ratio of the standard deviation to the clipping norm.
loss_reduction: An string description of how the loss is reduced over
examples. Currently supports 'mean' and 'sum'. If `None`, then the
aggregation type must be inferred from `input_model.loss`.
loss_model: An optional `tf.keras.Model` used to infer the loss reduction
strategy from if `loss_reduction` is `None`.
Returns:
A list of tensors containing the clipped gradients, but with the right
amount of Gaussian noise added to them (depending on the reduction
strategy of the loss function).
Raises:
ValueError: If both `loss_model` and `loss_reduction` are `None` or if
they are both not `None`.
"""
if loss_reduction is None and loss_model is None:
raise ValueError(
'Exactly one of `loss_reduction` and `loss_model` must be populated.'
' Instead, both arguments were `None`.'
)
if loss_reduction is not None and loss_model is not None:
raise ValueError(
'Exactly one of `loss_reduction` and `loss_model` must be populated.'
' Instead, both arguments were not `None`.'
)

if loss_reduction is None and loss_model is not None:
implicit_mean_reductions = [
tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
tf.keras.losses.Reduction.AUTO,
]
model_reduction = _infer_loss_reduction_type(loss_model)
loss_reduction = (
'mean' if model_reduction in implicit_mean_reductions else 'sum'
)
if model_reduction == tf.keras.losses.Reduction.AUTO:
logging.info(
'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.'
)

scale = l2_norm_clip
if loss_reduction == 'mean':
scale /= tf.cast(batch_size, tf.float32)

def add_noise(g):
return g + tf.random.normal(
tf.shape(g), mean=0.0, stddev=noise_multiplier * scale
)

return tf.nest.map_structure(add_noise, clipped_grads)
Loading

0 comments on commit 614c93f

Please sign in to comment.