forked from google-deepmind/deepmind-research
-
Notifications
You must be signed in to change notification settings - Fork 0
/
optimizer.py
611 lines (565 loc) · 25.8 KB
/
optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A module for the main curvature optimizer class."""
from typing import Any, Callable, Iterator, Mapping, Optional, Sequence, Tuple, Union
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jnr
from kfac_ferminet_alpha import estimator
from kfac_ferminet_alpha import tag_graph_matcher as tgm
from kfac_ferminet_alpha import utils
# Signature of a hyper-parameter schedule: it is called with the optimizer's
# step counter and returns the scheduled value (annotated as optional).
ScheduleType = Callable[[jnp.ndarray], Optional[jnp.ndarray]]
# The aliases below are intentionally loose (`Any`): this module does not
# constrain their structure through the type system.
# Model parameters, as passed to `value_and_grad_func`.
Parameters = Any
# A batch of data, as produced by the data iterator given to `Optimizer.step`.
Batch = Any
# Persistent state of the user function (see `value_func_has_state`).
FuncState = Any
# Optimizer state: a mapping from state-field name to its value.
State = Mapping[str, Any]
# NOTE(review): `infer_class_state` presumably derives the optimizer's state
# fields from the class-level annotations below -- confirm in `utils.Stateful`.
@utils.Stateful.infer_class_state
class Optimizer(utils.Stateful):
"""The default optimizer class.

Gradients are preconditioned with a curvature estimate maintained by an
`estimator.CurvatureEstimator` before being applied to the parameters. The
optimizer is finalized lazily: the first call to `init` or `step` traces the
user function and builds the estimator and the compiled init/burn-in/step
functions.
"""
# Pytree with the structure of the parameters; initialized to zeros in `_init`.
velocities: Parameters
# Curvature estimator; constructed in `finalize`.
estimator: estimator.CurvatureEstimator
# Number of completed optimizer steps; starts at 0, incremented in `_step`.
step_counter: jnp.ndarray
def __init__(
    self,
    value_and_grad_func,
    l2_reg: Union[float, jnp.ndarray],
    value_func_has_aux: bool = False,
    value_func_has_state: bool = False,
    value_func_has_rng: bool = False,
    learning_rate_schedule: Optional[ScheduleType] = None,
    momentum_schedule: Optional[ScheduleType] = None,
    damping_schedule: Optional[ScheduleType] = None,
    min_damping: Union[float, jnp.ndarray] = 1e-8,
    max_damping: Union[float, jnp.ndarray] = jnp.inf,
    norm_constraint: Optional[Union[float, jnp.ndarray]] = None,
    num_burnin_steps: int = 10,
    estimation_mode: str = "fisher_gradients",
    curvature_ema: Union[float, jnp.ndarray] = 0.95,
    inverse_update_period: int = 5,
    register_only_generic: bool = False,
    layer_tag_to_block_cls: Optional[estimator.TagMapping] = None,
    patterns_to_skip: Sequence[str] = (),
    donate_parameters: bool = False,
    donate_optimizer_state: bool = False,
    donate_batch_inputs: bool = False,
    donate_func_state: bool = False,
    batch_process_func: Optional[Callable[[Any], Any]] = None,
    multi_device: bool = False,
    use_jax_cond: bool = True,
    debug: bool = False,
    pmap_axis_name="kfac_axis",
):
  """Initializes the K-FAC optimizer with the given settings.

  Args:
    value_and_grad_func: Python callable. The function should return the value
      of the loss to be optimized and its gradients. If the argument
      `value_func_has_aux` is `False` then the interface should be:
        loss, loss_grads = value_and_grad_func(params, batch)
      If `value_func_has_aux` is `True` then the interface should be:
        (loss, aux), loss_grads = value_and_grad_func(params, batch)
    l2_reg: Scalar. Set this value to tell the optimizer what L2
      regularization coefficient you are using (if any). Note the coefficient
      appears in the regularizer as coeff / 2 * sum(param**2). Note that the
      user is still responsible for adding regularization to the loss.
    value_func_has_aux: Boolean. Specifies whether the provided callable
      `value_and_grad_func` returns the loss value only, or also some
      auxiliary data. (Default: False)
    value_func_has_state: Boolean. Specifies whether the provided callable
      `value_and_grad_func` has a persistent state that is taken as input and
      whether it also outputs an updated version of it. (Default: False)
    value_func_has_rng: Boolean. Specifies whether the provided callable
      `value_and_grad_func` additionally takes as input an rng key.
      (Default: False)
    learning_rate_schedule: Callable. A schedule for the learning rate. This
      should take as input the current step number and return a single
      `jnp.ndarray` that represents the learning rate. (Default: None)
    momentum_schedule: Callable. A schedule for the momentum. This should take
      as input the current step number and return a single `jnp.ndarray`
      that represents the momentum. The schedule is wrapped so that the
      momentum at step 0 is always zero. (Default: None)
    damping_schedule: Callable. A schedule for the damping. This should take
      as input the current step number and return a single `jnp.ndarray`
      that represents the damping. (Default: None)
    min_damping: Scalar. Minimum value the damping parameter can take. Note
      that the default value of 1e-8 is quite arbitrary, and you may have to
      adjust this up or down for your particular problem. If you are using a
      non-zero value of l2_reg you *may* be able to set this to
      zero. (Default: 1e-8)
    max_damping: Scalar. Maximum value the damping parameter can take.
      (Default: Infinity)
    norm_constraint: Scalar. If specified, the update is scaled down so that
      its approximate squared Fisher norm `v^T F v` is at most the specified
      value. (Note that here `F` is the approximate curvature matrix, not the
      exact.) (Default: None)
    num_burnin_steps: Int. At the start of optimization, e.g. the first step,
      before performing the actual step the optimizer will perform this many
      updates to the curvature approximation without updating the actual
      parameters. (Default: 10)
    estimation_mode: String. The type of estimator to use for the curvature
      matrix. Can be one of:
        * fisher_empirical
        * fisher_exact
        * fisher_gradients
        * fisher_curvature_prop
        * ggn_exact
        * ggn_curvature_prop
      See the doc-string for CurvatureEstimator (in estimator.py) for a more
      detailed description of these options. (Default: 'fisher_gradients').
    curvature_ema: The decay factor used when calculating the covariance
      estimate moving averages. (Default: 0.95)
    inverse_update_period: Int. The number of steps in between updating the
      computation of the inverse curvature approximation. (Default: 5)
    register_only_generic: Boolean. Whether when running the auto-tagger to
      register only generic parameters, or allow it to use the graph matcher
      to automatically pick up any kind of layer tags. (Default: False)
    layer_tag_to_block_cls: Dictionary. A mapping from layer tags to block
      classes which override the default choices of block approximation for
      that specific tag. See the doc-string for CurvatureEstimator (in
      estimator.py) for a more detailed description of this.
    patterns_to_skip: Tuple. A list of any patterns that should be skipped by
      the graph matcher when auto-tagging.
    donate_parameters: Boolean. Whether to use jax's `donate_argnums` to
      donate the parameter values of each call to `step`. Note that this
      implies that you will not be able to access the old parameter values'
      buffers after calling into `step`.
    donate_optimizer_state: Boolean. Whether to use jax's `donate_argnums` to
      donate the optimizer state of each call to `step`. Note that this
      implies that you will not be able to access the old optimizer state
      values' buffers after calling into `step`.
    donate_batch_inputs: Boolean. Whether to use jax's `donate_argnums` to
      donate the batch values of each call to `step`. Note that this implies
      that you will not be able to access the old batch values' buffers after
      calling into `step`.
    donate_func_state: Boolean. Whether to use jax's `donate_argnums` to
      donate the persistent function state of each call to `step`. Note that
      this implies that you will not be able to access the old function state
      values' buffers after calling into `step`.
    batch_process_func: Callable. A function to be called on each batch
      before feeding it to the KFAC on device. This could be useful for
      specific device input optimizations. (Default: identity)
    multi_device: Boolean. Whether to use `pmap` and run the optimizer on
      multiple devices. (Default: False)
    use_jax_cond: Not used for the moment.
    debug: Boolean. If True, none of the step or init functions are jitted.
      Note that this also overrides `multi_device` and prevents using `pmap`.
      (Default: False)
    pmap_axis_name: String. The name of the `pmap` axis to use when
      `multi_device` is set to True. (Default: 'kfac_axis')
  """
  super().__init__()
  self.value_and_grad_func = value_and_grad_func
  self.value_func_has_aux = value_func_has_aux
  self.value_func_has_state = value_func_has_state
  self.value_func_has_rng = value_func_has_rng
  self.value_func = utils.convert_value_and_grad_to_value_func(
      value_and_grad_func, has_aux=value_func_has_aux)
  self.l2_reg = l2_reg
  self.learning_rate_schedule = learning_rate_schedule
  if momentum_schedule is not None:

    def schedule_with_first_step_zero(global_step: jnp.ndarray):
      # The momentum is forced to zero on the very first step. `jnp.where`
      # selects between the two values instead of masking by multiplication,
      # so a non-finite schedule value at step 0 cannot turn into NaN.
      value = momentum_schedule(global_step)
      return jnp.where(
          jnp.equal(global_step, 0), jnp.zeros_like(value), value)

    self.momentum_schedule = schedule_with_first_step_zero
  else:
    self.momentum_schedule = None
  self.damping_schedule = damping_schedule
  self.min_damping = min_damping
  self.max_damping = max_damping
  self.norm_constraint = norm_constraint
  self.num_burnin_steps = num_burnin_steps
  self.estimation_mode = estimation_mode
  self.curvature_ema = curvature_ema
  self.inverse_update_period = inverse_update_period
  self.register_only_generic = register_only_generic
  self.layer_tag_to_block_cls = layer_tag_to_block_cls
  self.patterns_to_skip = patterns_to_skip
  self.donate_parameters = donate_parameters
  self.donate_optimizer_state = donate_optimizer_state
  self.donate_batch_inputs = donate_batch_inputs
  self.donate_func_state = donate_func_state
  # Default to the identity when no batch preprocessing is requested.
  self.batch_process_func = batch_process_func or (lambda x: x)
  self.multi_device = multi_device
  self.use_jax_cond = use_jax_cond
  self.debug = debug
  # The axis name is only meaningful under `pmap`.
  self.pmap_axis_name = pmap_axis_name if multi_device else None
  self._rng_split = utils.p_split if multi_device else jnr.split

  # Attributes filled in during self.finalize()
  self.finalized = False
  self.tagged_func = None
  self.flat_params_shapes = None
  self.params_treedef = None

  # Special attributes related to jitting/pmap
  self._jit_init = None
  self._jit_burnin = None
  self._jit_step = None
def finalize(
    self,
    params: Parameters,
    rng: jnp.ndarray,
    batch: Batch,
    func_state: Optional[FuncState] = None,
) -> None:
  """Finalizes the optimizer by tracing the model function.

  Builds the tagged function, the curvature estimator and the (optionally
  jitted or pmapped) init, burn-in and step functions.

  Args:
    params: The parameters of the model. When `multi_device` is True they are
      assumed to be replicated along the leading axis.
    rng: A Jax PRNG key, only used to build example function arguments.
    batch: An example batch of data.
    func_state: Optional persistent state of the user function.

  Raises:
    ValueError: If the optimizer has already been finalized.
  """
  if self.finalized:
    raise ValueError("Optimizer has already been finalized.")
  if self.multi_device:
    # We assume that the parameters and batch are replicated, while tracing
    # must happen with parameters for a single device call.
    # `jax.tree_util.tree_map` replaces the deprecated top-level alias
    # `jax.tree_map`, which newer JAX releases no longer provide.
    params, rng, batch = jax.tree_util.tree_map(
        lambda x: x[0], (params, rng, batch))
    if func_state is not None:
      func_state = jax.tree_util.tree_map(lambda x: x[0], func_state)
  batch = self.batch_process_func(batch)
  # Example arguments used for tracing the user function.
  func_args = utils.make_func_args(params, func_state, rng, batch,
                                   self.value_func_has_state,
                                   self.value_func_has_rng)

  # Remember the parameter structure so `_init` can build the velocities.
  flat_params, self.params_treedef = jax.tree_util.tree_flatten(params)
  self.flat_params_shapes = tuple(p.shape for p in flat_params)
  self.tagged_func = tgm.auto_register_tags(
      func=self.value_func,
      func_args=func_args,
      params_index=0,
      register_only_generic=self.register_only_generic,
      patterns_to_skip=self.patterns_to_skip)
  self.estimator = estimator.CurvatureEstimator(
      self.tagged_func,
      func_args,
      self.l2_reg,
      self.estimation_mode,
      layer_tag_to_block_cls=self.layer_tag_to_block_cls)

  # Positional arguments of `_step`: params, opt_state, rng, batch, func_state
  donation_flags = (
      (0, self.donate_parameters),
      (1, self.donate_optimizer_state),
      (3, self.donate_batch_inputs),
      # The function state can only be donated if the function has a state.
      (4, self.donate_func_state and self.value_func_has_state),
  )
  donate_argnums = tuple(index for index, flag in donation_flags if flag)

  if self.debug:
    # Debug mode takes precedence over `multi_device`: nothing is compiled.
    self._jit_init = self._init
    self._jit_burnin = self._burnin
    self._jit_step = self._step
  elif self.multi_device:
    self._jit_init = jax.pmap(
        self._init, axis_name=self.pmap_axis_name, donate_argnums=[0])
    # batch size is static argnum and is at index 5
    self._jit_burnin = jax.pmap(
        self._burnin,
        axis_name=self.pmap_axis_name,
        static_broadcasted_argnums=[5])
    self._jit_step = jax.pmap(
        self._step,
        axis_name=self.pmap_axis_name,
        donate_argnums=donate_argnums,
        static_broadcasted_argnums=[5])
  else:
    self._jit_init = jax.jit(self._init, donate_argnums=[0])
    # batch size is static argnum and is at index 5
    self._jit_burnin = jax.jit(self._burnin, static_argnums=[5])
    self._jit_step = jax.jit(
        self._step, donate_argnums=donate_argnums, static_argnums=[5])
  self.finalized = True
def _init(self, rng: jnp.ndarray) -> State:
  """This is the non-jitted version of initializing the state.

  Args:
    rng: A Jax PRNG key, forwarded to the estimator's `init`.

  Returns:
    A dict with zero `velocities` (same tree structure as the parameters seen
    in `finalize`), the estimator's initial state and a zero `step_counter`.
  """
  flat_velocities = [jnp.zeros(shape) for shape in self.flat_params_shapes]
  # `jax.tree_util.tree_unflatten` replaces the deprecated top-level alias
  # `jax.tree_unflatten`, which newer JAX releases no longer provide.
  velocities = jax.tree_util.tree_unflatten(self.params_treedef,
                                            flat_velocities)
  return dict(
      velocities=velocities,
      # No initial damping is stored in the estimator state; `_step` sets it.
      estimator=self.estimator.init(rng, None),
      step_counter=jnp.asarray(0))
def verify_args_and_get_step_counter(
    self,
    params: Parameters,
    state: State,
    rng: jnp.ndarray,
    data_iterator: Iterator[Batch],
    func_state: Optional[FuncState] = None,
    learning_rate: Optional[jnp.ndarray] = None,
    momentum: Optional[jnp.ndarray] = None,
    damping: Optional[jnp.ndarray] = None,
    global_step_int: Optional[int] = None,
) -> int:
  """Validates the arguments of `Optimizer.step` and returns the step count.

  If the optimizer has not been finalized yet, it is finalized here, which
  consumes one batch from `data_iterator`.

  Args:
    params: The parameters of the model.
    state: The state of the optimizer.
    rng: A Jax PRNG key.
    data_iterator: An iterator that returns a batch of data.
    func_state: Optional persistent state of the user function.
    learning_rate: Explicit learning rate; must be None with a schedule.
    momentum: Explicit momentum; must be None with a schedule.
    damping: Explicit damping; must be None with a schedule.
    global_step_int: The global step as a python int, if already known.

  Returns:
    `global_step_int` when provided, otherwise the step counter read from
    `state` converted to a python int.

  Raises:
    ValueError: If a value is passed for a quantity that has a schedule.
  """
  if not self.finalized:
    _, finalize_key = self._rng_split(rng)
    self.finalize(params, finalize_key, next(data_iterator), func_state)

  # A quantity may come from a schedule or from the caller, never from both.
  schedule_checks = (
      ("learning_rate_schedule", self.learning_rate_schedule, learning_rate),
      ("momentum_schedule", self.momentum_schedule, momentum),
      ("damping_schedule", self.damping_schedule, damping),
  )
  for schedule_name, schedule, passed_value in schedule_checks:
    if schedule is not None and passed_value is not None:
      raise ValueError(
          f"When you have passed a `{schedule_name}` you should not pass a "
          "value to the step function.")

  if global_step_int is not None:
    return global_step_int

  # Fall back to the counter stored in the optimizer state.
  counter = state["step_counter"]
  if self.multi_device:
    counter = utils.get_first(counter)
  return int(counter)
def _burnin(
    self,
    params: Parameters,
    state: State,
    rng: jnp.ndarray,
    batch: Batch,
    func_state: Optional[FuncState],
    batch_size: Optional[int],
) -> Tuple[State, Optional[FuncState]]:
  """This is the non-jitted version of a single burnin step.

  Updates the curvature estimate on one batch without touching the
  parameters.

  Args:
    params: The parameters of the model.
    state: The state of the optimizer.
    rng: A Jax PRNG key.
    batch: A batch of data.
    func_state: Optional persistent state of the user function.
    batch_size: The batch size; when None the leading dimension of the first
      array in `batch` is used.

  Returns:
    The updated optimizer state and the (possibly updated) function state.
  """
  self.set_state(state)
  batch = self.batch_process_func(batch)
  # Only consume randomness for the user function if it asks for a key.
  rng, func_rng = jnr.split(rng) if self.value_func_has_rng else (rng, None)
  func_args = utils.make_func_args(params, func_state, func_rng, batch,
                                   self.value_func_has_state,
                                   self.value_func_has_rng)

  # Compute batch size. `jax.tree_util.tree_leaves` replaces the deprecated
  # top-level alias `jax.tree_flatten`, which newer JAX releases removed.
  if batch_size is None:
    batch_size = jax.tree_util.tree_leaves(batch)[0].shape[0]

  # Update curvature estimate
  ema_old, ema_new = 1.0, 1.0 / self.num_burnin_steps
  self.estimator.update_curvature_matrix_estimate(ema_old, ema_new,
                                                  batch_size, rng, func_args,
                                                  self.pmap_axis_name)

  # Advance the function state by evaluating the function once more.
  if func_state is not None:
    out, _ = self.value_and_grad_func(*func_args)
    _, func_state, _ = utils.extract_func_outputs(out,
                                                  self.value_func_has_aux,
                                                  self.value_func_has_state)
  return self.pop_state(), func_state
def _step(
    self,
    params: Parameters,
    state: State,
    rng: jnp.ndarray,
    batch: Batch,
    func_state: Optional[FuncState],
    batch_size: Optional[int],
    learning_rate: Optional[jnp.ndarray],
    momentum: Optional[jnp.ndarray],
    damping: Optional[jnp.ndarray],
) -> Union[Tuple[Parameters, State, FuncState, Mapping[str, jnp.ndarray]],
           Tuple[Parameters, State, Mapping[str, jnp.ndarray]]]:
  """This is the non-jitted version of a single step.

  Args:
    params: The parameters of the model.
    state: The state of the optimizer.
    rng: A Jax PRNG key.
    batch: A batch of data.
    func_state: Optional persistent state of the user function.
    batch_size: The batch size; when None the leading dimension of the first
      array in `batch` is used.
    learning_rate: The learning rate; must be None iff a schedule was given.
    momentum: The momentum; must be None iff a schedule was given.
    damping: The damping; must be None iff a schedule was given.

  Returns:
    `(params, state, func_state, stats)` if `value_func_has_state` is True,
    otherwise `(params, state, stats)`.
  """
  # Unpack and set the state
  self.set_state(state)

  # Resolve every hyper-parameter from its schedule or the explicit value
  if self.learning_rate_schedule is not None:
    assert learning_rate is None
    learning_rate = self.learning_rate_schedule(self.step_counter)
  else:
    assert learning_rate is not None
  if self.momentum_schedule is not None:
    assert momentum is None
    momentum = self.momentum_schedule(self.step_counter)
  else:
    assert momentum is not None
  if self.damping_schedule is not None:
    assert damping is None
    damping = self.damping_schedule(self.step_counter)
  else:
    assert damping is not None
    # An explicit damping must not clash with one carried in the state.
    assert self.estimator.damping is None
  # The damping is installed on the estimator before any curvature
  # computation. Previously a scheduled damping only reached `stats` and
  # was never set on the estimator.
  self.estimator.damping = damping

  # Preprocess the batch and construct correctly the function arguments
  batch = self.batch_process_func(batch)
  rng, func_rng = jnr.split(rng) if self.value_func_has_rng else (rng, None)
  func_args = utils.make_func_args(params, func_state, func_rng, batch,
                                   self.value_func_has_state,
                                   self.value_func_has_rng)

  # Compute the batch size. `jax.tree_util.tree_leaves` replaces the
  # deprecated top-level alias `jax.tree_flatten`.
  if batch_size is None:
    batch_size = jax.tree_util.tree_leaves(batch)[0].shape[0]

  # Compute current loss and gradients
  out, grads = self.value_and_grad_func(*func_args)
  loss, new_func_state, aux = utils.extract_func_outputs(
      out, self.value_func_has_aux, self.value_func_has_state)
  # Sync loss and grads
  loss, grads = utils.pmean_if_pmap((loss, grads), self.pmap_axis_name)

  # Update curvature estimate
  self.estimator.update_curvature_matrix_estimate(
      self.curvature_ema,
      1.0,
      batch_size,
      rng,
      func_args,
      self.pmap_axis_name,
  )

  # Optionally update the inverse estimate: only every
  # `inverse_update_period` steps, otherwise the state passes through.
  self.estimator.set_state(
      lax.cond(
          self.step_counter % self.inverse_update_period == 0,
          lambda s: self.estimator.update_curvature_estimate_inverse(  # pylint: disable=g-long-lambda
              self.pmap_axis_name, s),
          lambda s: s,
          self.estimator.pop_state()))

  # Compute proposed directions
  vectors = self.propose_directions(
      grads,
      self.velocities,
      learning_rate,
      momentum,
  )

  # The learning rate is defined as the negative of the coefficient by which
  # we multiply the gradients, while the momentum is the coefficient by
  # which we multiply the velocities.
  neg_learning_rate = -learning_rate  # pytype: disable=unsupported-operands  # trace-all-classes
  # Compute the coefficients of the update vectors
  assert neg_learning_rate is not None and momentum is not None
  coefficients = (neg_learning_rate, momentum)

  # Update velocities and compute new delta
  self.velocities, delta = self.velocities_and_delta(
      self.velocities,
      vectors,
      coefficients,
  )

  # Update parameters: params = params + delta
  # (`jax.tree_util.tree_map` replaces the deprecated `jax.tree_map`).
  params = jax.tree_util.tree_map(jnp.add, params, delta)

  # The damping is not kept in the optimizer state between steps. No
  # reduction ratio is computed here, so `rho` is reported as NaN.
  self.estimator.damping = None
  rho = jnp.nan

  # Statistics with useful information
  stats = dict()
  stats["step"] = self.step_counter
  stats["loss"] = loss
  stats["learning_rate"] = -coefficients[0]
  stats["momentum"] = coefficients[1]
  stats["damping"] = damping
  stats["rho"] = rho
  if self.value_func_has_aux:
    stats["aux"] = aux
  self.step_counter = self.step_counter + 1

  if self.value_func_has_state:
    return params, self.pop_state(), new_func_state, stats
  else:
    assert new_func_state is None
    return params, self.pop_state(), stats
def init(
    self,
    params: Parameters,
    rng: jnp.ndarray,
    batch: Batch,
    func_state: Optional[FuncState] = None,
) -> State:
  """Builds and returns the initial optimizer state.

  The optimizer is finalized first if that has not happened yet.

  Args:
    params: The parameters of the model.
    rng: A Jax PRNG key.
    batch: An example batch of data, used for finalization.
    func_state: Optional persistent state of the user function.

  Returns:
    The initial optimizer state.
  """
  already_finalized = self.finalized
  if not already_finalized:
    self.finalize(params, rng, batch, func_state)
  initial_state = self._jit_init(rng)
  return initial_state
def step(
    self,
    params: Parameters,
    state: Mapping[str, Any],
    rng: jnp.ndarray,
    data_iterator: Iterator[Any],
    func_state: Any = None,
    learning_rate: Optional[jnp.ndarray] = None,
    momentum: Optional[jnp.ndarray] = None,
    damping: Optional[jnp.ndarray] = None,
    batch_size: Optional[int] = None,
    global_step_int: Optional[int] = None,
) -> Union[Tuple[Parameters, State, FuncState, Mapping[str, jnp.ndarray]],
           Tuple[Parameters, State, Mapping[str, jnp.ndarray]]]:
  """Performs a single update step using the optimizer.

  On the very first step (step counter equal to 0) the curvature estimate is
  first warmed up on `num_burnin_steps` batches drawn from `data_iterator`.

  Args:
    params: The parameters of the model.
    state: The state of the optimizer.
    rng: A Jax PRNG key.
    data_iterator: An iterator that returns a batch of data.
    func_state: Any function state that gets passed in and returned.
    learning_rate: This must be provided when `learning_rate_schedule=None`.
    momentum: This must be provided when `momentum_schedule=None`. It is
      replaced by zero on the first step.
    damping: This must be provided when `damping_schedule=None`.
    batch_size: The batch size to use for KFAC. The default behaviour when it
      is None is to use the leading dimension of the first data array.
    global_step_int: The global step as a python int. Note that this must
      match the step internal to the optimizer that is part of its state.

  Returns:
    (params, state, stats), or (params, state, func_state, stats) when the
    optimizer was built with `value_func_has_state=True`, where:
      params: The updated model parameters.
      state: The updated optimizer state.
      func_state: The updated function state.
      stats: A dictionary of key statistics provided to be logged.
  """
  current_step = self.verify_args_and_get_step_counter(
      params=params,
      state=state,
      rng=rng,
      data_iterator=data_iterator,
      func_state=func_state,
      learning_rate=learning_rate,
      momentum=momentum,
      damping=damping,
      global_step_int=global_step_int)

  if current_step == 0:
    # Warm up the curvature estimate before the first real update.
    for _ in range(self.num_burnin_steps):
      rng, burnin_key = self._rng_split(rng)
      burnin_batch = next(data_iterator)
      state, func_state = self._jit_burnin(params, state, burnin_key,
                                           burnin_batch, func_state,
                                           batch_size)

    # Without a schedule the momentum of the first step is forced to zero.
    if self.momentum_schedule is None:
      zero_momentum = jnp.zeros([])
      if self.multi_device:
        zero_momentum = utils.replicate_all_local_devices(zero_momentum)
      momentum = zero_momentum

  step_batch = next(data_iterator)
  return self._jit_step(params, state, rng, step_batch, func_state,
                        batch_size, learning_rate, momentum, damping)
def propose_directions(
    self,
    grads: Parameters,
    velocities: Parameters,
    learning_rate: Optional[jnp.ndarray],
    momentum: Optional[jnp.ndarray],
) -> Tuple[Parameters, Parameters]:
  """Returns the two update directions for the next step.

  Args:
    grads: The gradients of the loss.
    velocities: The current velocities; returned unchanged.
    learning_rate: The learning rate; required when `norm_constraint` is set.
    momentum: Unused here, kept for subclasses.

  Returns:
    A pair of the preconditioned gradients (rescaled to satisfy
    `norm_constraint` when one is set) and the velocities.
  """
  del momentum  # not used in this, but could be used in subclasses
  direction = self.estimator.multiply_matpower(grads, -1)
  if self.norm_constraint is None:
    return direction, velocities

  assert learning_rate is not None
  scaled_sq_norm = (
      utils.inner_product(direction, grads) * learning_rate**2)
  # The norm is averaged across devices because reductions can be
  # non-deterministic (by default they are on GPUs, for performance). Even
  # with synced inputs the inner product may differ between devices.
  scaled_sq_norm = utils.pmean_if_pmap(scaled_sq_norm, self.pmap_axis_name)
  # Only ever shrink the update, never enlarge it.
  shrink = jnp.minimum(jnp.sqrt(self.norm_constraint / scaled_sq_norm), 1)
  return utils.scalar_mul(direction, shrink), velocities
def velocities_and_delta(
    self,
    velocities: Parameters,
    vectors: Sequence[Parameters],
    coefficients: Sequence[jnp.ndarray],
) -> Sequence[Parameters]:
  """Computes the new velocities and delta (update to parameters).

  Args:
    velocities: The current velocities. Unused: the old velocities enter the
      computation only if the caller includes them in `vectors`.
    vectors: The update directions; must be non-empty.
    coefficients: One scalar coefficient per entry of `vectors`.

  Returns:
    A pair `(new_velocities, delta)`. Both are the same linear combination
    `sum_i coefficients[i] * vectors[i]`.
  """
  del velocities
  assert len(vectors) == len(coefficients)
  delta = utils.scalar_mul(vectors[0], coefficients[0])
  for vi, wi in zip(vectors[1:], coefficients[1:]):
    # `jax.tree_util.tree_map` replaces the deprecated top-level alias
    # `jax.tree_map`, which newer JAX releases no longer provide.
    delta = jax.tree_util.tree_map(jnp.add, delta, utils.scalar_mul(vi, wi))
  return delta, delta