adding train_state to all submissions
Niccolo-Ajroldi committed Oct 25, 2024
1 parent ce8eb18 commit 5a06a0d
Showing 30 changed files with 107 additions and 77 deletions.
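Every file gets the same edit: train_state moves from a required positional argument to an optional keyword argument at the end of update_params, and Optional is added to the typing imports where it was missing. A minimal sketch of the resulting signature is below; the parameters not visible in the hunks (current_param_container, current_params_types, model_state, hyperparameters) and the import path are filled in from the surrounding spec and are assumptions, and the body is illustrative only.

# Sketch of the post-commit update_params signature; body is illustrative only.
from typing import Any, Dict, List, Optional, Tuple

from algorithmic_efficiency import spec  # import path assumed


def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: spec.ParameterTypeTree,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: spec.Hyperparameters,
                  batch: Dict[str, spec.Tensor],
                  loss_type: spec.LossType,
                  optimizer_state: spec.OptimizerState,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng: spec.RandomState,
                  train_state: Optional[Dict[str, Any]] = None
                  ) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del current_params_types
  del loss_type
  del train_state  # Unused here; accepted so all submissions share one signature.
  raise NotImplementedError  # Each submission supplies its own optimizer step.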
@@ -260,10 +260,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params, updated_model_state)."""
 del current_params_types
 del loss_type
@@ -260,10 +260,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params, updated_model_state)."""
 del current_params_types
 del loss_type
@@ -1,7 +1,7 @@
 """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

 import math
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from absl import logging
 import torch
@@ -232,10 +232,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params, updated_model_state)."""
 del current_params_types
 del loss_type
@@ -1,7 +1,7 @@
 """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

 import math
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from absl import logging
 import torch
@@ -232,10 +232,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params, updated_model_state)."""
 del current_params_types
 del loss_type
@@ -272,10 +272,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params, updated_model_state)."""
 del current_params_types
 del loss_type
@@ -272,10 +272,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params, updated_model_state)."""
 del current_params_types
 del loss_type
@@ -1,7 +1,7 @@
 """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

 import math
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from absl import logging
 import torch
@@ -244,10 +244,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params, updated_model_state)."""
 del current_params_types
 del loss_type
@@ -1,7 +1,7 @@
 """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

 import math
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from absl import logging
 import torch
@@ -244,10 +244,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params, updated_model_state)."""
 del current_params_types
 del loss_type
@@ -1,7 +1,7 @@
 """Training algorithm track submission functions for CIFAR10."""

 import functools
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from flax import jax_utils
 import jax
@@ -118,10 +118,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params, updated_model_state)."""
 del current_params_types
 del loss_type
@@ -1,6 +1,6 @@
 """Training algorithm track submission functions for CIFAR10."""

-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 import torch
 from torch.optim.lr_scheduler import CosineAnnealingLR
@@ -61,10 +61,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params)."""
 del current_params_types
 del hyperparameters
@@ -1,7 +1,7 @@
 """Training algorithm track submission functions for MNIST."""

 import functools
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from flax import jax_utils
 import jax
@@ -83,10 +83,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params, updated_model_state)."""
 del current_params_types
 del loss_type
@@ -1,6 +1,6 @@
 """Training algorithm track submission functions for MNIST."""

-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 import torch

@@ -40,10 +40,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params)."""
 del hyperparameters
 del loss_type
@@ -1,7 +1,7 @@
 """Submission file for an Adafactor optimizer with warmup+cosine LR in Jax."""

 import functools
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from flax import jax_utils
 import jax
@@ -118,10 +118,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params, updated_model_state)."""
 del current_params_types
 del loss_type
@@ -1,7 +1,7 @@
 """Submission file for Adafactor in PyTorch."""

 from functools import partial
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from absl import logging
 import torch
@@ -198,10 +198,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params, updated_model_state)."""
 del current_params_types
 del loss_type
7 changes: 4 additions & 3 deletions reference_algorithms/paper_baselines/adamw/jax/submission.py
@@ -1,7 +1,7 @@
 """Submission file for an AdamW optimizer with warmup+cosine LR in Jax."""

 import functools
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from flax import jax_utils
 import jax
@@ -118,10 +118,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params, updated_model_state)."""
 del current_params_types
 del loss_type
@@ -1,6 +1,6 @@
 """Submission file for an AdamW optimizer with warmup+cosine LR in PyTorch."""

-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from absl import logging
 import torch
@@ -59,10 +59,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params, updated_model_state)."""
 del current_params_types
 del loss_type
7 changes: 4 additions & 3 deletions reference_algorithms/paper_baselines/lamb/jax/submission.py
@@ -1,7 +1,7 @@
 """Submission file for a LAMB optimizer with warmup+cosine LR in Jax."""

 import functools
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from flax import jax_utils
 import jax
@@ -126,10 +126,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params, updated_model_state)."""
 del current_params_types
 del loss_type
@@ -1,7 +1,7 @@
 """Submission file for a LAMB optimizer with warmup+cosine LR in PyTorch."""

 import math
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from absl import logging
 import torch
@@ -197,10 +197,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params, updated_model_state)."""
 del current_params_types
 del loss_type
@@ -152,10 +152,11 @@ def update_params(workload: spec.Workload,
 batch: Dict[str, spec.Tensor],
 loss_type: spec.LossType,
 optimizer_state: spec.OptimizerState,
-train_state: Dict[str, Any],
 eval_results: List[Tuple[int, float]],
 global_step: int,
-rng: spec.RandomState) -> spec.UpdateReturn:
+rng: spec.RandomState,
+train_state: Optional[Dict[str, Any]] = None
+) -> spec.UpdateReturn:
 """Return (updated_optimizer_state, updated_params, updated_model_state)."""
 del current_params_types
 del loss_type
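Because the new parameter defaults to None, call sites that never passed train_state keep working, and submissions that ignore it can simply delete it like the other unused arguments. A submission that does read it should guard against the default; a small self-contained sketch of that guard follows, where the 'global_step' key is a made-up placeholder rather than something this diff defines.

from typing import Any, Dict, Optional


def steps_so_far(global_step: int,
                 train_state: Optional[Dict[str, Any]] = None) -> int:
  """Prefer a value from train_state when the caller provides one."""
  if train_state is None:
    return global_step
  # 'global_step' is an assumed key, used only for this illustration.
  return train_state.get('global_step', global_step)


print(steps_so_far(7))                        # 7: no train_state passed
print(steps_so_far(7, {'global_step': 12}))   # 12: value read from train_state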