Skip to content

Commit

Permalink
add train_state to all instances of `update_params`, passing it by …
Browse files Browse the repository at this point in the history
…(shallow) copy in submission_runner
  • Loading branch information
Niccolo-Ajroldi committed Oct 3, 2024
1 parent a23b5ea commit e09bbf5
Show file tree
Hide file tree
Showing 33 changed files with 88 additions and 25 deletions.
1 change: 1 addition & 0 deletions DOCUMENTATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def update_params(
batch: Dict[str, Tensor],
loss_type: LossType,
optimizer_state: OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: RandomState
Expand Down
2 changes: 2 additions & 0 deletions algorithmic_efficiency/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def init_optimizer_state(workload: Workload,
Dict[str, Tensor],
LossType,
OptimizerState,
Dict[str, Any],
List[Tuple[int, float]],
int,
RandomState
Expand All @@ -422,6 +423,7 @@ def update_params(workload: Workload,
batch: Dict[str, Tensor],
loss_type: LossType,
optimizer_state: OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: RandomState) -> UpdateReturn:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

optimizer_state, opt_update_fn = optimizer_state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

optimizer_state, opt_update_fn = optimizer_state
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

import math
from typing import Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Tuple, Any

from absl import logging
import torch
Expand Down Expand Up @@ -232,12 +232,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

current_model = current_param_container
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

import math
from typing import Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Tuple, Any

from absl import logging
import torch
Expand Down Expand Up @@ -232,12 +232,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

current_model = current_param_container
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del hyperparameters

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del hyperparameters

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

import math
from typing import Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Tuple, Any

from absl import logging
import torch
Expand Down Expand Up @@ -244,12 +244,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del hyperparameters

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

import math
from typing import Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Tuple, Any

from absl import logging
import torch
Expand Down Expand Up @@ -244,12 +244,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del hyperparameters

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Training algorithm track submission functions for CIFAR10."""

import functools
from typing import Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Tuple, Any

from flax import jax_utils
import jax
Expand Down Expand Up @@ -118,13 +118,15 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del global_step
del train_state
del eval_results
optimizer_state, opt_update_fn = optimizer_state
per_device_rngs = jax.random.split(rng, jax.local_device_count())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Training algorithm track submission functions for CIFAR10."""

from typing import Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Tuple, Any

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
Expand Down Expand Up @@ -61,13 +61,15 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del current_params_types
del hyperparameters
del loss_type
del train_state
del eval_results

current_model = current_param_container
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Training algorithm track submission functions for MNIST."""

import functools
from typing import Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Tuple, Any

from flax import jax_utils
import jax
Expand Down Expand Up @@ -83,12 +83,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del global_step

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Training algorithm track submission functions for MNIST."""

from typing import Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Tuple, Any

import torch

Expand Down Expand Up @@ -40,13 +40,15 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del hyperparameters
del loss_type
del current_params_types
del train_state
del eval_results
del global_step

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for an Adafactor optimizer with warmup+cosine LR in Jax."""

import functools
from typing import Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Tuple, Any

from flax import jax_utils
import jax
Expand Down Expand Up @@ -118,12 +118,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

optimizer_state, opt_update_fn = optimizer_state
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for Adafactor in PyTorch."""

from functools import partial
from typing import Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Tuple, Any

from absl import logging
import torch
Expand Down Expand Up @@ -198,12 +198,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

current_model = current_param_container
Expand Down
4 changes: 3 additions & 1 deletion reference_algorithms/paper_baselines/adamw/jax/submission.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for an AdamW optimizer with warmup+cosine LR in Jax."""

import functools
from typing import Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Tuple, Any

from flax import jax_utils
import jax
Expand Down Expand Up @@ -118,12 +118,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

optimizer_state, opt_update_fn = optimizer_state
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Submission file for an AdamW optimizer with warmup+cosine LR in PyTorch."""

from typing import Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Tuple, Any

from absl import logging
import torch
Expand Down Expand Up @@ -59,12 +59,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

current_model = current_param_container
Expand Down
4 changes: 3 additions & 1 deletion reference_algorithms/paper_baselines/lamb/jax/submission.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for a LAMB optimizer with warmup+cosine LR in Jax."""

import functools
from typing import Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Tuple, Any

from flax import jax_utils
import jax
Expand Down Expand Up @@ -126,12 +126,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

optimizer_state, opt_update_fn = optimizer_state
Expand Down
Loading

0 comments on commit e09bbf5

Please sign in to comment.