Sparsity Preserving DP-SGD in TF Privacy
Add function to merge varname_to_contribution_count_fn maps from different layers.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 660525767
tensorflower-gardener committed Aug 12, 2024
1 parent e42b574 commit 10a44ff
Showing 10 changed files with 273 additions and 8 deletions.
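For orientation before the per-file diff: each registered sparse-noise layer produces a map from variable name to a contribution-count function, and this commit adds a helper that merges those maps across layers while rejecting duplicate variable names. The snippet below is a standalone sketch of just that merge semantics; the function name and simplified signature are illustrative, not the library API (the real helper, `extract_varname_to_contribution_counts_fns`, appears in `sparse_noise_utils.py` further down).

# Standalone sketch of the merge semantics added in this commit.
# Illustrative only: the real helper consumes RegistryGeneratorFunctionOutput
# instances and also filters by trainable variables (see the diff below).


def merge_varname_maps(per_layer_maps):
  """Merges per-layer {varname: count_fn} maps, rejecting duplicate names."""
  merged = {}
  for layer_map in per_layer_maps:
    duplicates = set(layer_map) & set(merged)
    if duplicates:
      raise ValueError(f'Duplicate varnames: {duplicates}')
    merged.update(layer_map)
  return merged


# Two layers contributing distinct variables merge cleanly.
merged = merge_varname_maps([
    {'embedding_1/embeddings': lambda grad: None},
    {'embedding_2/embeddings': lambda grad: None},
])
assert set(merged) == {'embedding_1/embeddings', 'embedding_2/embeddings'}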
8 changes: 7 additions & 1 deletion tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
@@ -46,6 +46,8 @@ py_library(
        ":common_manip_utils",
        ":layer_registry",
        ":type_aliases",
        "//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
        "//tensorflow_privacy/privacy/sparsity_preserving_noise:type_aliases",
    ],
)

@@ -55,7 +57,11 @@ py_test(
    python_version = "PY3",
    shard_count = 8,
    srcs_version = "PY3",
    deps = [
        ":gradient_clipping_utils",
        ":layer_registry",
        "//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
    ],
)

py_library(
@@ -164,6 +164,7 @@ def compute_gradient_norms(
  registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
      tape=tape,
      layer_registry=layer_registry,
      sparse_noise_layer_registry=None,
      num_microbatches=num_microbatches,
  )
  layer_grad_vars, generator_outputs_list = (
@@ -132,6 +132,7 @@ def _run_model_forward_backward_pass(
  registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
      tape=tape,
      layer_registry=layer_registry.make_default_layer_registry(),
      sparse_noise_layer_registry=None,
      num_microbatches=None,
  )
  layer_grad_vars, registry_fn_outputs_list = (
@@ -22,13 +22,18 @@
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases as sn_type_aliases


@dataclasses.dataclass(frozen=True)
class RegistryGeneratorFunctionOutput:
  layer_id: str
  layer_vars: Optional[Sequence[tf.Variable]]
  layer_sqr_norm_fn: Optional[type_aliases.SquareNormFunction]
  varname_to_count_contribution_fn: Optional[
      dict[str, sn_type_aliases.ContributionCountHistogramFn]
  ]
  layer_trainable_weights: Optional[Sequence[tf.Variable]]


@@ -46,6 +51,7 @@ def has_internal_compute_graph(input_object: Any):
def get_registry_generator_fn(
    tape: tf.GradientTape,
    layer_registry: lr.LayerRegistry,
    sparse_noise_layer_registry: snlr.LayerRegistry,
    num_microbatches: Optional[type_aliases.BatchSize] = None,
) -> Optional[Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]]:
  """Creates the generator function for `model_forward_backward_pass()`.
@@ -58,6 +64,10 @@ def get_registry_generator_fn(
      `output` is the pre-activator tensor, `sqr_grad_norms` is related to the
      squared norms of a layer's pre-activation tensor, and `vars` are relevant
      trainable weights.
    sparse_noise_layer_registry: A `LayerRegistry` instance containing functions
      that help compute contribution counts for sparse noise. See
      `tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry` for
      more details.
    num_microbatches: An optional number or scalar `tf.Tensor` for the number of
      microbatches. If not None, indicates that the loss is grouped into
      num_microbatches (in this case, the batch dimension needs to be a multiple
@@ -83,6 +93,16 @@ def registry_generator_fn(layer_instance, args, kwargs):
            'be used for efficient gradient clipping.'
            % layer_instance.__class__.__name__
        )
      varname_to_count_contribution_fn = None
      if sparse_noise_layer_registry and sparse_noise_layer_registry.is_elem(
          layer_instance
      ):
        count_contribution_registry_fn = sparse_noise_layer_registry.lookup(
            layer_instance
        )
        varname_to_count_contribution_fn = count_contribution_registry_fn(
            layer_instance, args, kwargs, num_microbatches
        )
      registry_fn = layer_registry.lookup(layer_instance)
      (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
          layer_instance, args, kwargs, tape, num_microbatches
@@ -91,6 +111,7 @@ def registry_generator_fn(layer_instance, args, kwargs):
          layer_id=str(id(layer_instance)),
          layer_vars=layer_vars,
          layer_sqr_norm_fn=layer_sqr_norm_fn,
          varname_to_count_contribution_fn=varname_to_count_contribution_fn,
          layer_trainable_weights=layer_instance.trainable_weights,
      )
    else:
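For context, a minimal sketch of how a sparse-noise registry entry plugs into the generator above, modeled on the unit test below: the registry function is invoked as `registry_fn(layer_instance, args, kwargs, num_microbatches)` and returns a dict mapping variable names to contribution-count functions. The 'embeddings' key and the trivial count function are placeholders, not the real registry entry.

import tensorflow as tf
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr


def embedding_registry_fn(layer_instance, args, kwargs, num_microbatches):
  # Placeholder: a real entry would build a function that computes a
  # contribution-count histogram from the layer's inputs.
  del layer_instance, args, kwargs, num_microbatches  # Unused in this sketch.
  return {'embeddings': lambda grad: None}


sparse_registry = snlr.LayerRegistry()
sparse_registry.insert(tf.keras.layers.Embedding, embedding_registry_fn)
# `sparse_registry` can then be passed as `sparse_noise_layer_registry` to
# `get_registry_generator_fn`.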
@@ -17,6 +17,8 @@
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr


# ==============================================================================
@@ -175,5 +177,92 @@ def test_new_custom_layer_spec(self):
    )


class RegistryGeneratorFnTest(tf.test.TestCase, parameterized.TestCase):

  def _get_sparse_layer_registry(self):
    def count_contribution_fn(_):
      return None

    def registry_fn(*_):
      return {'var': count_contribution_fn}

    registry = snlr.LayerRegistry()
    registry.insert(tf.keras.layers.Embedding, registry_fn)
    return registry, count_contribution_fn

  def _get_layer_registry(self):
    var = tf.Variable(1.0)
    output = tf.ones((1, 1))

    def sqr_norm_fn(_):
      return None

    def registry_fn(*_):
      return [var], output, sqr_norm_fn

    registry = lr.LayerRegistry()
    registry.insert(tf.keras.layers.Embedding, registry_fn)
    registry.insert(tf.keras.layers.Dense, registry_fn)
    return registry, var, output, sqr_norm_fn

  def test_registry_generator_fn(self):
    inputs = tf.constant([[0, 1]])
    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(10, 1),
        tf.keras.layers.Dense(1),
    ])

    sparse_layer_registry, count_contribution_fn = (
        self._get_sparse_layer_registry()
    )
    layer_registry, var, output, sqr_norm_fn = self._get_layer_registry()
    registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
        tape=tf.GradientTape(),
        layer_registry=layer_registry,
        sparse_noise_layer_registry=sparse_layer_registry,
        num_microbatches=None,
    )
    embedding_layer = model.layers[0]
    out, embedding_registry_generator_fn_output = registry_generator_fn(
        embedding_layer,
        [inputs],
        {},
    )
    expected_embedding_registry_generator_fn_output = (
        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
            layer_id=str(id(embedding_layer)),
            layer_vars=[var],
            layer_sqr_norm_fn=sqr_norm_fn,
            varname_to_count_contribution_fn={'var': count_contribution_fn},
            layer_trainable_weights=embedding_layer.trainable_weights,
        )
    )
    self.assertEqual(
        embedding_registry_generator_fn_output,
        expected_embedding_registry_generator_fn_output,
    )
    self.assertEqual(out, output)
    dense_layer = model.layers[1]
    out, dense_registry_generator_fn_output = registry_generator_fn(
        dense_layer,
        [inputs],
        {},
    )
    expected_dense_registry_generator_fn_output = (
        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
            layer_id=str(id(dense_layer)),
            layer_vars=[var],
            layer_sqr_norm_fn=sqr_norm_fn,
            varname_to_count_contribution_fn=None,
            layer_trainable_weights=dense_layer.trainable_weights,
        )
    )
    self.assertEqual(
        dense_registry_generator_fn_output,
        expected_dense_registry_generator_fn_output,
    )
    self.assertEqual(out, output)


if __name__ == '__main__':
  tf.test.main()
1 change: 1 addition & 0 deletions tensorflow_privacy/privacy/keras_models/dp_keras_model.py
@@ -280,6 +280,7 @@ def train_step(self, data):
        gradient_clipping_utils.get_registry_generator_fn(
            tape=tape,
            layer_registry=self._layer_registry,
            sparse_noise_layer_registry=None,
            num_microbatches=num_microbatches,
        )
    )
9 changes: 8 additions & 1 deletion tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD
@@ -5,12 +5,19 @@ licenses(["notice"])
py_library(
    name = "sparse_noise_utils",
    srcs = ["sparse_noise_utils.py"],
    deps = [
        ":type_aliases",
        "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
    ],
)

py_test(
    name = "sparse_noise_utils_test",
    srcs = ["sparse_noise_utils_test.py"],
    deps = [
        ":sparse_noise_utils",
        "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
    ],
)

py_library(
@@ -16,10 +16,13 @@
For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357.
"""

import collections
from typing import Mapping, Optional, Sequence

from scipy import stats
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases
import tensorflow_probability as tfp


@@ -288,15 +291,60 @@ def add_sparse_gradient_noise(
  )


def extract_varname_to_contribution_counts_fns(
    registry_fn_outputs_list: Sequence[
        gradient_clipping_utils.RegistryGeneratorFunctionOutput
    ],
    trainable_vars: Sequence[tf.Variable],
) -> dict[str, type_aliases.ContributionCountHistogramFn]:
  """Extracts a map of contribution count fns from generator outputs.

  Args:
    registry_fn_outputs_list: A list of `RegistryGeneratorFunctionOutput`
      instances returned by
      `gradient_clipping_utils.model_forward_backward_pass`.
    trainable_vars: A list of trainable variables.

  Returns:
    A `dict` mapping variable names to contribution count functions.
  """
  if trainable_vars is not None:
    # Create a set using `ref()` for fast set membership check. tf.Variable
    # itself is not hashable.
    trainable_vars = set([v.ref() for v in trainable_vars])

  varname_to_contribution_counts_fns = collections.defaultdict(list)
  for registry_fn_output in registry_fn_outputs_list:
    if trainable_vars is None or any(
        w.ref() in trainable_vars
        for w in registry_fn_output.layer_trainable_weights
    ):
      if registry_fn_output.varname_to_count_contribution_fn is not None:
        duplicate_varnames = set(
            registry_fn_output.varname_to_count_contribution_fn.keys()
        ) & set(varname_to_contribution_counts_fns.keys())
        if duplicate_varnames:
          raise ValueError(
              f'Duplicate varnames: {duplicate_varnames} found in contribution'
              ' counts functions.'
          )
        varname_to_contribution_counts_fns.update(
            registry_fn_output.varname_to_count_contribution_fn
        )
  return varname_to_contribution_counts_fns


def get_contribution_counts(
    trainable_vars: Sequence[tf.Variable],
    grads: Sequence[tf.Tensor],
    varname_to_contribution_counts_fns: Mapping[
        str, type_aliases.ContributionCountHistogramFn
    ],
) -> list[type_aliases.ContributionCountHistogram | None]:
  """Gets the contribution counts for each variable in the Model.

  Args:
    trainable_vars: A list of trainable variables.
    grads: A corresponding list of gradients for each trainable variable.
    varname_to_contribution_counts_fns: A mapping from variable name to a
      function that gets the contribution counts for that variable.
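The body of `get_contribution_counts` is truncated in this diff. The sketch below is an assumption about the lookup it performs, based only on the signature and docstring: match each trainable variable to its contribution-count function by name and apply it to the corresponding gradient, yielding None where nothing is registered.

# Hedged sketch only; the actual implementation is not shown in this diff.
def get_contribution_counts_sketch(trainable_vars, grads, varname_to_fns):
  contribution_counts = []
  for var, grad in zip(trainable_vars, grads):
    fn = varname_to_fns.get(var.name)
    # Variables without a registered function get no contribution counts.
    contribution_counts.append(fn(grad) if fn is not None else None)
  return contribution_counts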