Skip to content

Commit

Permalink
print step hint
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Sep 28, 2023
1 parent bb29602 commit 3738f35
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 1 addition & 1 deletion algorithmic_efficiency/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def save_checkpoint(framework: str,
"""
if framework == 'jax':
model_params = jax.device_get(jax_utils.unreplicate(model_params))
opt_state, _ = optimizer_state
opt_state, _, _ = optimizer_state
opt_state = jax.device_get(jax_utils.unreplicate(opt_state))
model_state = jax.device_get(jax_utils.unreplicate(model_state))
else:
Expand Down
2 changes: 2 additions & 0 deletions reference_algorithms/target_setting_algorithms/jax_nadamw.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jax
import jax.numpy as jnp
import optax
from absl import logging

from algorithmic_efficiency import spec
from reference_algorithms.target_setting_algorithms import cosine_warmup
Expand Down Expand Up @@ -152,6 +153,7 @@ def init_optimizer_state(workload: spec.Workload,
del rng

target_setting_step_hint = int(0.75 * workload.step_hint)
logging.info(f'target setting step hint: {target_setting_step_hint}')
lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint,
hyperparameters)

Expand Down

0 comments on commit 3738f35

Please sign in to comment.