Skip to content

Commit

Permalink
ensure backward compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
Niccolo-Ajroldi committed Oct 25, 2024
1 parent 1f59285 commit ce8eb18
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
8 changes: 4 additions & 4 deletions algorithmic_efficiency/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,10 +401,10 @@ def init_optimizer_state(workload: Workload,
Dict[str, Tensor],
LossType,
OptimizerState,
Dict[str, Any],
List[Tuple[int, float]],
int,
RandomState
RandomState,
Optional[Dict[str, Any]]
],
UpdateReturn]

Expand All @@ -423,10 +423,10 @@ def update_params(workload: Workload,
batch: Dict[str, Tensor],
loss_type: LossType,
optimizer_state: OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: RandomState) -> UpdateReturn:
rng: RandomState,
train_state: Optional[Dict[str, Any]] = None) -> UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
pass

Expand Down
11 changes: 9 additions & 2 deletions submission_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
import datetime
import gc
import importlib
from inspect import signature
import itertools
import json
import os
import struct
import time
from types import MappingProxyType
from typing import Any, Dict, Optional, Tuple

from absl import app
Expand Down Expand Up @@ -273,6 +275,10 @@ def train_once(
hyperparameters,
opt_init_rng)
logging.info('Initializing metrics bundle.')

# Check if 'train_state' is in the function signature
needs_train_state = 'train_state' in signature(update_params).parameters

# Bookkeeping.
train_state = {
'validation_goal_reached': False,
Expand Down Expand Up @@ -357,10 +363,11 @@ def train_once(
batch=batch,
loss_type=workload.loss_type,
optimizer_state=optimizer_state,
train_state=train_state.copy(),
eval_results=eval_results,
global_step=global_step,
rng=update_rng)
rng=update_rng,
**({'train_state': MappingProxyType(train_state)}
if needs_train_state else {}))
except spec.TrainingCompleteError:
train_state['training_complete'] = True
global_step += 1
Expand Down

0 comments on commit ce8eb18

Please sign in to comment.