Add support for Tensor learning rates and gradients with mixed types.
PiperOrigin-RevId: 671726026
ZacharyGarrett authored and copybara-github committed Sep 7, 2024
1 parent c3e36d7 commit 4c70837
Showing 7 changed files with 92 additions and 41 deletions.
RELEASE.md: 6 changes (4 additions, 2 deletions)
@@ -26,8 +26,10 @@ and this project adheres to

### Fixed

* A bug where `tff.learning.optimizers.build_adafactor(...)` would update its
step counter twice upon every invocation of `.next()`.
* A bug where `tff.learning.optimizers.build_adafactor` would update its step
counter twice upon every invocation of `.next()`.
* A bug where tensor learning rates for `tff.learning.optimizers.build_sgdm`
would fail with mixed dtype gradients.

### Removed

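In practical terms, the second bullet above is the user-visible fix in this commit: optimizer hyperparameters may be tf.Tensor values applied to model weights whose leaves have different float dtypes. A minimal usage sketch (not part of the commit), modeled directly on the tests added below:

import tensorflow as tf
from tensorflow_federated.python.learning.optimizers import sgdm

# Mixed-dtype model weights and a tensor-valued learning rate; before this
# change the multiply inside `next` failed when the dtypes did not match.
weights = (
    tf.constant([0.1], dtype=tf.float32),
    tf.constant(1.0, dtype=tf.float64),
    tf.constant([10.0, 10.0], dtype=tf.bfloat16),
)
optimizer = sgdm.build_sgdm(learning_rate=tf.constant(0.1, dtype=tf.float32))
state = optimizer.initialize(weights)
gradients = tf.nest.map_structure(tf.zeros_like, weights)
state, updated_weights = optimizer.next(state, weights, gradients)

The same structure is exercised for Adagrad and Adam by the new test_lr_with_different_weight_dtypes tests further down.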
tensorflow_federated/python/learning/optimizers/adagrad.py: 22 changes (14 additions, 8 deletions)
@@ -27,7 +27,7 @@
_HPARAMS_KEYS = [optimizer.LEARNING_RATE_KEY, _EPSILON_KEY]

State = TypeVar('State', bound=collections.OrderedDict[str, Any])
Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, float])
Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, Any])


class _Adagrad(optimizer.Optimizer[State, optimizer.Weights, Hparams]):
@@ -40,31 +40,35 @@ def __init__(
epsilon: optimizer.Float = 1e-7,
):
"""Initializes SGD optimizer."""
if learning_rate < 0.0:
if not tf.is_symbolic_tensor(learning_rate) and learning_rate < 0.0:
raise ValueError(
f'Adagrad `learning_rate` must be nonnegative, found {learning_rate}.'
)
if initial_preconditioner_value < 0.0:
if (
not tf.is_symbolic_tensor(initial_preconditioner_value)
and initial_preconditioner_value < 0.0
):
raise ValueError(
'Adagrad `initial_preconditioner_value` must be nonnegative, found '
f'{initial_preconditioner_value}.'
)
if epsilon < 0.0:
if not tf.is_symbolic_tensor(epsilon) and epsilon < 0.0:
raise ValueError(f'Adagrad epsilon must be nonnegative, found {epsilon}.')
self._lr = learning_rate
self._initial_precond = initial_preconditioner_value
self._epsilon = epsilon

def initialize(self, specs: Any) -> State:
initial_preconditioner = tf.nest.map_structure(
lambda s: tf.ones(s.shape, s.dtype) * self._initial_precond, specs
lambda s: tf.ones(s.shape, s.dtype)
* tf.cast(self._initial_precond, s.dtype),
specs,
)
state = collections.OrderedDict([
return collections.OrderedDict([
(optimizer.LEARNING_RATE_KEY, self._lr),
(_EPSILON_KEY, self._epsilon),
(_PRECONDITIONER_KEY, initial_preconditioner),
])
return state

def next(
self, state: State, weights: optimizer.Weights, gradients: Any
Expand All @@ -82,7 +86,9 @@ def _adagrad_update(w, p, g):
if g is None:
return w, p
p = p + tf.math.square(g)
w = w - lr * g / tf.math.sqrt(p + epsilon)
w = w - tf.cast(lr, g.dtype) * g / tf.math.sqrt(
p + tf.cast(epsilon, p.dtype)
)
return w, p

updated_weights, updated_preconditioner = nest_utils.map_at_leaves(
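Two patterns recur in the adagrad.py hunks above. First, the eager range checks on hyperparameters are skipped when the value is a symbolic tensor, since no concrete value is available to compare at construction time. Second, scalar hyperparameters are cast into each leaf's dtype rather than relying on implicit conversion. A standalone sketch of the casting pattern used by `initialize` (names and values here are illustrative only):

import tensorflow as tf

initial_precond = tf.constant(0.1, dtype=tf.float32)
specs = (tf.TensorSpec([2], tf.float64), tf.TensorSpec([], tf.bfloat16))

# Casting the scalar into each leaf's dtype keeps the arithmetic within a
# single dtype; float32 * float64 would otherwise raise an error.
preconditioner = tf.nest.map_structure(
    lambda s: tf.ones(s.shape, s.dtype) * tf.cast(initial_precond, s.dtype),
    specs,
)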
tensorflow_federated/python/learning/optimizers/adagrad_test.py: 20 changes (18 additions, 2 deletions)
@@ -145,8 +145,8 @@ def random_vector():
genarator.normal(shape=s.shape, dtype=s.dtype) for s in weight_spec
]

intial_weight = random_vector()
model_variables_fn = lambda: [tf.Variable(v) for v in intial_weight]
initial_weight = random_vector()
model_variables_fn = lambda: [tf.Variable(v) for v in initial_weight]
gradients = [random_vector() for _ in range(steps)]
tff_optimizer_fn = lambda: adagrad.build_adagrad(0.01)
keras_optimizer_fn = lambda: tf.keras.optimizers.Adagrad(0.01)
@@ -227,6 +227,22 @@ def test_set_get_hparams_is_no_op(self, spec):
updated_state = optimizer.set_hparams(state, hparams)
self.assertEqual(state, updated_state)

def test_lr_with_different_weight_dtypes(self):
weights = (
tf.constant([0.1], dtype=tf.float32),
tf.constant(1.0, dtype=tf.float64),
tf.constant([10.0, 10.0], dtype=tf.bfloat16),
)
adagrad_optimizer = adagrad.build_adagrad(
learning_rate=tf.constant(0.1, dtype=tf.float32),
initial_preconditioner_value=tf.constant(0.1, dtype=tf.float32),
epsilon=tf.constant(0.1, dtype=tf.float64),
)
state = adagrad_optimizer.initialize(weights)
adagrad_optimizer.next(
state, weights, tf.nest.map_structure(tf.zeros_like, weights)
)


if __name__ == '__main__':
tf.test.main()
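A rough sketch (not part of the test file above) of what test_lr_with_different_weight_dtypes exercises, with an added assertion that `next` preserves each leaf's dtype:

import tensorflow as tf
from tensorflow_federated.python.learning.optimizers import adagrad

weights = (
    tf.constant([0.1], dtype=tf.float32),
    tf.constant(1.0, dtype=tf.float64),
)
opt = adagrad.build_adagrad(learning_rate=tf.constant(0.1, dtype=tf.float32))
state = opt.initialize(weights)
state, new_weights = opt.next(
    state, weights, tf.nest.map_structure(tf.zeros_like, weights)
)
assert all(w.dtype == n.dtype for w, n in zip(weights, new_weights))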
tensorflow_federated/python/learning/optimizers/adam.py: 25 changes (13 additions, 12 deletions)
@@ -36,7 +36,7 @@
]

State = TypeVar('State', bound=collections.OrderedDict[str, Any])
Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, float])
Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, Any])


class _Adam(optimizer.Optimizer[State, optimizer.Weights, Hparams]):
@@ -50,19 +50,19 @@ def __init__(
epsilon: optimizer.Float = 1e-7,
):
"""Initializes Adam optimizer."""
if learning_rate < 0.0:
if not tf.is_symbolic_tensor(learning_rate) and learning_rate < 0.0:
raise ValueError(
f'Adam `learning_rate` must be nonnegative, found {learning_rate}.'
)
if beta_1 < 0.0 or beta_1 > 1.0:
if not tf.is_symbolic_tensor(beta_1) and (beta_1 < 0.0 or beta_1 > 1.0):
raise ValueError(
f'Adam `beta_1` must be in the range [0.0, 1.0], found {beta_1}.'
)
if beta_2 < 0.0 or beta_2 > 1.0:
if not tf.is_symbolic_tensor(beta_2) and (beta_2 < 0.0 or beta_2 > 1.0):
raise ValueError(
f'Adam `beta_2` must be in the range [0.0, 1.0], found {beta_2}.'
)
if epsilon < 0.0:
if not tf.is_symbolic_tensor(epsilon) and epsilon < 0.0:
raise ValueError(f'Adam `epsilon` must be nonnegative, found {epsilon}.')
self._lr = learning_rate
self._beta_1 = beta_1
@@ -76,7 +76,7 @@ def initialize(self, specs: Any) -> State:
initial_preconditioner = tf.nest.map_structure(
lambda s: tf.zeros(s.shape, s.dtype), specs
)
state = collections.OrderedDict([
return collections.OrderedDict([
(optimizer.LEARNING_RATE_KEY, self._lr),
(_BETA_1_KEY, self._beta_1),
(_BETA_2_KEY, self._beta_2),
@@ -85,7 +85,6 @@ def initialize(self, specs: Any) -> State:
(_ACCUMULATOR_KEY, initial_accumulator),
(_PRECONDITIONER_KEY, initial_preconditioner),
])
return state

def next(
self, state: State, weights: optimizer.Weights, gradients: Any
@@ -105,16 +104,18 @@ def next(
)
normalized_lr = (
lr
* tf.math.sqrt((1 - tf.math.pow(beta_2, tf.cast(step, tf.float32))))
/ (1 - tf.math.pow(beta_1, tf.cast(step, tf.float32)))
* tf.math.sqrt((1.0 - tf.math.pow(beta_2, step)))
/ (1.0 - tf.math.pow(beta_1, step))
)

def _adam_update(w, a, p, g):
if g is None:
return w, a, p
a = a + (g - a) * (1 - beta_1)
p = p + (tf.math.square(g) - p) * (1 - beta_2)
w = w - normalized_lr * a / (tf.math.sqrt(p) + epsilon)
a = a + (g - a) * (1 - tf.cast(beta_1, a.dtype))
p = p + (tf.math.square(g) - p) * (1 - tf.cast(beta_2, p.dtype))
w = w - tf.cast(normalized_lr, a.dtype) * a / (
tf.math.sqrt(p) + tf.cast(epsilon, p.dtype)
)
return w, a, p

updated_weights, updated_accumulator, updated_preconditioner = (
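The normalized_lr expression above is Adam's usual bias correction, lr_t = lr * sqrt(1 - beta_2^t) / (1 - beta_1^t), computed once per call to `next` and then cast into each leaf's dtype along with the other hyperparameters. A hedged restatement of the per-leaf step (the helper name and argument list are illustrative, not the module's internals):

import tensorflow as tf

# `normalized_lr` is assumed to be the bias-corrected learning rate above:
#   normalized_lr = lr * tf.math.sqrt(1.0 - beta_2**t) / (1.0 - beta_1**t)
def adam_leaf_update(w, a, p, g, normalized_lr, beta_1, beta_2, epsilon):
  # First- and second-moment estimates, with hyperparameters cast to the
  # leaf's dtype so mixed-dtype models work.
  a = a + (g - a) * (1.0 - tf.cast(beta_1, a.dtype))
  p = p + (tf.math.square(g) - p) * (1.0 - tf.cast(beta_2, p.dtype))
  w = w - tf.cast(normalized_lr, a.dtype) * a / (
      tf.math.sqrt(p) + tf.cast(epsilon, p.dtype)
  )
  return w, a, p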
tensorflow_federated/python/learning/optimizers/adam_test.py: 25 changes (20 additions, 5 deletions)
@@ -55,9 +55,7 @@ def test_math(self):
for _ in range(4):
state, weights = optimizer.next(state, weights, gradients)
history.append(weights)
self.assertAllClose(
[[1.0], [0.9000007], [0.8000017], [0.700002], [0.600003]], history
)
self.assertAllClose([[1.0], [0.9], [0.8], [0.7], [0.6]], history)

@parameterized.named_parameters(
('scalar_spec', _SCALAR_SPEC),
@@ -142,8 +140,8 @@ def random_vector():
genarator.normal(shape=s.shape, dtype=s.dtype) for s in weight_spec
]

intial_weight = random_vector()
model_variables_fn = lambda: [tf.Variable(v) for v in intial_weight]
initial_weight = random_vector()
model_variables_fn = lambda: [tf.Variable(v) for v in initial_weight]
gradients = [random_vector() for _ in range(steps)]
tff_optimizer_fn = lambda: adam.build_adam(0.01, 0.9, 0.999)
keras_optimizer_fn = lambda: tf.keras.optimizers.Adam(0.01, 0.9, 0.999)
@@ -225,6 +223,23 @@ def test_set_get_hparams_is_no_op(self, spec):
updated_state = optimizer.set_hparams(state, hparams)
self.assertEqual(state, updated_state)

def test_lr_with_different_weight_dtypes(self):
weights = (
tf.constant([0.1], dtype=tf.float32),
tf.constant(1.0, dtype=tf.float64),
tf.constant([10.0, 10.0], dtype=tf.bfloat16),
)
adam_optimizer = adam.build_adam(
learning_rate=tf.constant(0.1, dtype=tf.float32),
beta_1=tf.constant(0.1, dtype=tf.float32),
beta_2=tf.constant(0.1, dtype=tf.float32),
epsilon=tf.constant(0.1, dtype=tf.float64),
)
state = adam_optimizer.initialize(weights)
adam_optimizer.next(
state, weights, tf.nest.map_structure(tf.zeros_like, weights)
)


if __name__ == '__main__':
tf.test.main()
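With the Hparams bound relaxed from float to Any in these modules, the hyperparameters surfaced by get_hparams may themselves be tensors, and set_hparams simply round-trips them, which is what the test_set_get_hparams_is_no_op cases assert. An illustrative sketch (the keyword-only call below assumes build_adam provides defaults for beta_1, beta_2, and epsilon):

import tensorflow as tf
from tensorflow_federated.python.learning.optimizers import adam

weights = (tf.constant([0.1], dtype=tf.float32),)
opt = adam.build_adam(learning_rate=tf.constant(0.01, dtype=tf.float32))
state = opt.initialize(weights)
hparams = opt.get_hparams(state)               # may now contain tensor values
state_after = opt.set_hparams(state, hparams)  # a no-op round trip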
tensorflow_federated/python/learning/optimizers/sgdm.py: 15 changes (6 additions, 9 deletions)
@@ -26,7 +26,7 @@
_ACCUMULATOR_KEY = 'accumulator'

State = TypeVar('State', bound=collections.OrderedDict[str, Any])
Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, float])
Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, Any])


class _SGD(optimizer.Optimizer[State, optimizer.Weights, Hparams]):
@@ -38,14 +38,16 @@ def __init__(
momentum: Optional[optimizer.Float] = None,
):
"""Initializes SGD optimizer."""
if learning_rate < 0.0:
if not tf.is_symbolic_tensor(learning_rate) and learning_rate < 0.0:
raise ValueError(
f'SGD `learning_rate` must be nonnegative, found {learning_rate}.'
)
if momentum:
# We should only track momentum as a hparam in the case that it is both
# specified and nonzero.
if momentum < 0.0 or momentum > 1.0:
if not tf.is_symbolic_tensor(momentum) and (
momentum < 0.0 or momentum > 1.0
):
raise ValueError(
'SGD `momentum` must be `None` or in the range [0, 1], found '
f'{momentum}.'
@@ -77,7 +79,7 @@ def next(
def _sgd_update(w, g):
if g is None:
return w
return w - lr * g
return w - tf.cast(lr, dtype=g.dtype) * g

updated_weights = nest_utils.map_at_leaves(
_sgd_update, weights, gradients
Expand Down Expand Up @@ -111,11 +113,6 @@ def get_hparams(self, state: State) -> Hparams:
return collections.OrderedDict([(k, state[k]) for k in self._hparams_keys])

def set_hparams(self, state: State, hparams: Hparams) -> State:
# TODO: b/245962555 - Find an alternative to `update_struct` if it
# interferes with typing guarantees.
# We use `tff.structure.update_struct` (rather than something like
# `copy.deepcopy`) to ensure that this can be called within a
# `tff.Computation`.
return structure.update_struct(state, **hparams)


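The one-line change in _sgd_update is the fix called out in RELEASE.md: the (possibly tensor) learning rate is cast to each gradient's dtype before the multiply. A standalone sketch of the failure it avoids, using plain TensorFlow:

import tensorflow as tf

lr = tf.constant(0.1, dtype=tf.float32)   # tensor learning rate
w = tf.constant(1.0, dtype=tf.float64)    # a float64 model weight
g = tf.constant(0.5, dtype=tf.float64)    # its gradient

# `w - lr * g` raises an error because the float32 learning rate cannot be
# multiplied with a float64 gradient; casting per leaf avoids this.
updated = w - tf.cast(lr, dtype=g.dtype) * g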
tensorflow_federated/python/learning/optimizers/sgdm_test.py: 20 changes (17 additions, 3 deletions)
@@ -43,7 +43,7 @@ def test_get_hparams_momentum(self, momentum_value):
optimizer = sgdm.build_sgdm(0.01, momentum=momentum_value)
state = optimizer.initialize(_SCALAR_SPEC)
hparams = optimizer.get_hparams(state)
# Whether we specify None momentum or momentum 0.0, we shouldnt track the
# Whether we specify None momentum or momentum 0.0, we shouldn't track the
# extra accumulator state. The implementation of next checks for the
# presence or absence of momentum key--it should not be there in either
# case.
@@ -177,8 +177,8 @@ def random_vector():
genarator.normal(shape=s.shape, dtype=s.dtype) for s in weight_spec
]

intial_weight = random_vector()
model_variables_fn = lambda: [tf.Variable(v) for v in intial_weight]
initial_weight = random_vector()
model_variables_fn = lambda: [tf.Variable(v) for v in initial_weight]
gradients = [random_vector() for _ in range(steps)]
tff_optimizer_fn = lambda: sgdm.build_sgdm(learning_rate, momentum)

@@ -306,6 +306,20 @@ def test_set_get_hparams_is_no_op_with_momentum(self, spec):
updated_state = optimizer.set_hparams(state, hparams)
self.assertEqual(state, updated_state)

def test_lr_with_different_weight_dtypes(self):
weights = (
tf.constant([0.1], dtype=tf.float32),
tf.constant(1.0, dtype=tf.float64),
tf.constant([10.0, 10.0], dtype=tf.bfloat16),
)
sgdm_optimizer = sgdm.build_sgdm(
learning_rate=tf.constant(0.1, dtype=tf.float32)
)
state = sgdm_optimizer.initialize(weights)
sgdm_optimizer.next(
state, weights, tf.nest.map_structure(tf.zeros_like, weights)
)


if __name__ == '__main__':
tf.test.main()
