Sparsity Preserving DP-SGD in TF Privacy
Move get_registry_generator_fn from clip_grads.py to gradient_clipping_utils.py and change return type to dataclass.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 660548431
tensorflower-gardener committed Aug 7, 2024
1 parent d3f527e commit 8294cec
Showing 2 changed files with 78 additions and 44 deletions.
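For orientation before the diff, here is a minimal standalone sketch (not part of this commit, and it does not import tensorflow_privacy) contrasting the old positional-tuple return with the new frozen dataclass. `_SketchGeneratorOutput` is a hypothetical stand-in; its field names mirror `RegistryGeneratorFunctionOutput` as defined in the diff below.

import dataclasses
from typing import Any, Optional, Sequence

@dataclasses.dataclass(frozen=True)
class _SketchGeneratorOutput:
  # Hypothetical stand-in mirroring RegistryGeneratorFunctionOutput.
  layer_id: str
  layer_vars: Optional[Sequence[Any]]
  layer_sqr_norm_fn: Optional[Any]
  layer_trainable_weights: Optional[Sequence[Any]]

# Before: callers positionally unpacked an anonymous 4-tuple.
layer_id, layer_vars, sqr_norm_fn, weights = ('example-layer-id', [], None, [])

# After: callers read named, immutable fields instead of remembering positions.
out = _SketchGeneratorOutput(
    layer_id='example-layer-id',
    layer_vars=[],
    layer_sqr_norm_fn=None,
    layer_trainable_weights=[],
)
print(out.layer_id, out.layer_trainable_weights)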
52 changes: 10 additions & 42 deletions tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
@@ -32,43 +32,6 @@
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


-def get_registry_generator_fn(
-    tape: tf.GradientTape,
-    layer_registry: lr.LayerRegistry,
-    num_microbatches: Optional[type_aliases.BatchSize] = None,
-):
-  """Creates the generator function for `compute_gradient_norms()`."""
-  if layer_registry is None:
-    # Needed for backwards compatibility.
-    registry_generator_fn = None
-  else:

-    def registry_generator_fn(layer_instance, args, kwargs):
-      if layer_instance.trainable_variables:
-        # Only trainable variables factor into the gradient.
-        if not layer_registry.is_elem(layer_instance):
-          raise NotImplementedError(
-              'Layer %s is not in the registry of known layers that can '
-              'be used for efficient gradient clipping.'
-              % layer_instance.__class__.__name__
-          )
-        registry_fn = layer_registry.lookup(layer_instance)
-        (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
-            layer_instance, args, kwargs, tape, num_microbatches
-        )
-        return layer_outputs, (
-            str(id(layer_instance)),
-            layer_vars,
-            layer_sqr_norm_fn,
-            layer_instance.trainable_weights,
-        )
-      else:
-        # Non-trainable layer.
-        return layer_instance(*args, **kwargs), None

-  return registry_generator_fn


def _infer_per_example_loss_fn(model: tf.keras.Model):
  """Infer the per-example loss from model config."""

@@ -190,7 +153,7 @@ def compute_gradient_norms(
    are applied prior to clipping.
  """
  tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
-  registry_generator_fn = get_registry_generator_fn(
+  registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
      tape, layer_registry, num_microbatches
  )
  # First loop computes the model outputs, summed loss, and generator outputs.
@@ -241,12 +204,17 @@ def compute_gradient_norms(
  # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
  # occurrence. This is an over-estimate of the actual norm. For more details,
  # see the explanation in go/dp-sgd-shared-weights.
-  for layer_id, v, f, weights_list in filtered_outputs:
+  for registry_fn_output in filtered_outputs:
    if trainable_vars is None or any(
-        w.ref() in trainable_vars for w in weights_list
+        w.ref() in trainable_vars
+        for w in registry_fn_output.layer_trainable_weights
    ):
-      layer_vars[layer_id].append(v)
-      layer_sqr_norm_fns[layer_id].append(f)
+      layer_vars[registry_fn_output.layer_id].append(
+          registry_fn_output.layer_vars
+      )
+      layer_sqr_norm_fns[registry_fn_output.layer_id].append(
+          registry_fn_output.layer_sqr_norm_fn
+      )
  # Second loop evaluates the squared L2 norm functions and appends the results.
  layer_grad_vars = tape.gradient(
      summed_loss,
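An aside on the shared-weights comment in the hunk above: sqrt(k * \sum_i c_i^2) over-estimates the true norm because ||g_1 + ... + g_k|| <= c_1 + ... + c_k <= sqrt(k * (c_1^2 + ... + c_k^2)), first by the triangle inequality and then by Cauchy-Schwarz. A standalone numeric check of that chain (illustrative only, not taken from the library):

import math
import numpy as np

rng = np.random.default_rng(0)
k = 3                                             # occurrences of the shared variable
grads = [rng.normal(size=5) for _ in range(k)]    # per-occurrence gradient contributions
c = [float(np.linalg.norm(g)) for g in grads]     # per-occurrence norm estimates c_i
true_norm = float(np.linalg.norm(np.sum(grads, axis=0)))
bound = math.sqrt(k * sum(ci ** 2 for ci in c))   # the over-estimate used in the comment

assert true_norm <= sum(c) + 1e-9   # triangle inequality
assert sum(c) <= bound + 1e-9       # Cauchy-Schwarz
print(f'true norm {true_norm:.3f} <= sum of c_i {sum(c):.3f} <= bound {bound:.3f}')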
70 changes: 68 additions & 2 deletions tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py
@@ -13,14 +13,23 @@
# limitations under the License.
"""Utility functions that help in the computation of per-example gradient norms."""

-from collections.abc import Sequence, Set
-from typing import Any, Optional
+from collections.abc import Callable, Sequence, Set
+import dataclasses
+from typing import Any, Optional, Tuple

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


+@dataclasses.dataclass(frozen=True)
+class RegistryGeneratorFunctionOutput:
+  layer_id: str
+  layer_vars: Optional[Sequence[tf.Variable]]
+  layer_sqr_norm_fn: Optional[type_aliases.SquareNormFunction]
+  layer_trainable_weights: Optional[Sequence[tf.Variable]]


def has_internal_compute_graph(input_object: Any):
"""Checks if input is a TF model and has a TF internal compute graph."""
return (
@@ -32,6 +41,63 @@ def has_internal_compute_graph(input_object: Any):
  )


+def get_registry_generator_fn(
+    tape: tf.GradientTape,
+    layer_registry: lr.LayerRegistry,
+    num_microbatches: Optional[type_aliases.BatchSize] = None,
+) -> Optional[Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]]:
+  """Creates the generator function for `model_forward_backward_pass()`.

+  Args:
+    tape: The `tf.GradientTape` to use for the gradient computation.
+    layer_registry: A `dict` of layers that support "fast" gradient norm
+      computations. The key is the class of the layer and the value is a
+      function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
+      `output` is the pre-activation tensor, `sqr_grad_norms` is related to
+      the squared norms of a layer's pre-activation tensor, and `vars` are the
+      relevant trainable variables.
+    num_microbatches: An optional number or scalar `tf.Tensor` for the number
+      of microbatches. If not None, indicates that the loss is grouped into
+      num_microbatches (in this case, the batch dimension needs to be a
+      multiple of num_microbatches).

+  Returns:
+    A function that, given a layer instance and its call arguments, returns a
+    `tuple` `(layer_outputs, RegistryGeneratorFunctionOutput)` (the second
+    entry is `None` for layers without trainable variables), or `None` if
+    `layer_registry` is `None`.
+  """
+  if layer_registry is None:
+    # Needed for backwards compatibility.
+    registry_generator_fn = None
+  else:

+    def registry_generator_fn(layer_instance, args, kwargs):
+      if layer_instance.trainable_variables:
+        # Only trainable variables factor into the gradient.
+        if not layer_registry.is_elem(layer_instance):
+          raise NotImplementedError(
+              'Layer %s is not in the registry of known layers that can '
+              'be used for efficient gradient clipping.'
+              % layer_instance.__class__.__name__
+          )
+        registry_fn = layer_registry.lookup(layer_instance)
+        (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
+            layer_instance, args, kwargs, tape, num_microbatches
+        )
+        return layer_outputs, RegistryGeneratorFunctionOutput(
+            layer_id=str(id(layer_instance)),
+            layer_vars=layer_vars,
+            layer_sqr_norm_fn=layer_sqr_norm_fn,
+            layer_trainable_weights=layer_instance.trainable_weights,
+        )
+      else:
+        # Non-trainable layer.
+        return layer_instance(*args, **kwargs), None

+  return registry_generator_fn


def model_forward_pass(
    input_model: tf.keras.Model,
    inputs: type_aliases.PackedTensors,
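To close, a hedged usage sketch of the relocated helper, assuming a tensorflow_privacy build that includes this commit and the stock `layer_registry.make_default_layer_registry()` (which registers `tf.keras.layers.Dense`). It is illustrative only: in the library the generator function is driven by the model forward pass rather than called by hand.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

layer = tf.keras.layers.Dense(4)
x = tf.random.normal([8, 3])
layer.build(x.shape)  # build first so trainable_variables is non-empty

registry = layer_registry.make_default_layer_registry()
with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:
  generator_fn = gradient_clipping_utils.get_registry_generator_fn(
      tape, registry, num_microbatches=None
  )
  outputs, generator_output = generator_fn(layer, (x,), {})

# The second return value is now a RegistryGeneratorFunctionOutput dataclass,
# accessed by field name instead of positional unpacking.
print(outputs.shape)
print(generator_output.layer_id)
print([w.name for w in generator_output.layer_trainable_weights])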
