diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 607f47ead..8722a441e 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -199,6 +199,7 @@ def update_params( batch: Dict[str, Tensor], loss_type: LossType, optimizer_state: OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: RandomState @@ -212,6 +213,7 @@ def update_params( - The `loss_fn` produces a loss per example and a summed loss (both only for one device), which both can be used. - Allowed to update state for the optimizer. - Uses the `model_fn` of the `workload` in order to decouple the loss from the model so that model outputs (forward passes) can be reused (by storing them in the optimizer state). +- The submission can access the elapsed training time and get further information about the evaluation through `train_state`. - The submission can access the target evaluation metric via the `workload` variable. - **A call to this function will be considered a step** - The time between a call to this function and the next call to this function will be considered the per-step time. diff --git a/README.md b/README.md index 5a1f10a33..516c8eb1b 100644 --- a/README.md +++ b/README.md @@ -89,7 +89,7 @@ python3 submission_runner.py \ --workload=mnist \ --experiment_dir=$HOME/experiments \ --experiment_name=my_first_experiment \ - --submission_path=reference_algorithms/paper_baselines/adamw/jax/submission.py \ + --submission_path=reference_algorithms/paper_baselines/adamw/pytorch/submission.py \ --tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json ``` diff --git a/algorithmic_efficiency/pytorch_utils.py b/algorithmic_efficiency/pytorch_utils.py index 4f6c254bd..590f500fa 100644 --- a/algorithmic_efficiency/pytorch_utils.py +++ b/algorithmic_efficiency/pytorch_utils.py @@ -67,10 +67,13 @@ def update_batch_norm_fn(module: spec.ParameterContainer, ) if isinstance(module, bn_layers): if not update_batch_norm: - module.eval() - module.momentum_backup = module.momentum + if not hasattr(module, 'momentum_backup'): + module.momentum_backup = module.momentum + # module.momentum can be float or torch.Tensor. - module.momentum = 0. 
* module.momentum_backup + if torch.is_tensor(module.momentum_backup): + module.momentum = torch.zeros_like(module.momentum_backup) + else: + module.momentum = 0.0 elif hasattr(module, 'momentum_backup'): module.momentum = module.momentum_backup - module.track_running_stats = update_batch_norm diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index 285983957..7bc86b505 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -403,7 +403,8 @@ def init_optimizer_state(workload: Workload, OptimizerState, List[Tuple[int, float]], int, - RandomState + RandomState, + Optional[Dict[str, Any]] ], UpdateReturn] @@ -424,7 +425,8 @@ def update_params(workload: Workload, optimizer_state: OptimizerState, eval_results: List[Tuple[int, float]], global_step: int, - rng: RandomState) -> UpdateReturn: + rng: RandomState, + train_state: Optional[Dict[str, Any]] = None) -> UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" pass diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py index 834c93b7a..059352fb6 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py @@ -28,11 +28,16 @@ class ResNet(nn.Module): @nn.compact def __call__(self, x: spec.Tensor, - update_batch_norm: bool = True) -> spec.Tensor: + update_batch_norm: bool = True, + use_running_average_bn: bool = None) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + + # Preserve default behavior for backwards compatibility + if use_running_average_bn is None: + use_running_average_bn = not update_batch_norm norm = functools.partial( nn.BatchNorm, - use_running_average=not update_batch_norm, + use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, dtype=self.dtype) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index b019d1cee..8268c6ca3 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -110,7 +110,9 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} @@ -119,14 +121,16 @@ def model_fn( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=['batch_stats']) + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn) return logits, new_model_state else: logits = self._model.apply( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=False) + mutable=False, + use_running_average_bn=use_running_average_bn) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py index 99a9b0513..34cd17440 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py +++ 
b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py @@ -84,11 +84,16 @@ class ResNet(nn.Module): @nn.compact def __call__(self, x: spec.Tensor, - update_batch_norm: bool = True) -> spec.Tensor: + update_batch_norm: bool = True, + use_running_average_bn: Optional[bool] = None) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + + # Preserve default behavior for backwards compatibility + if use_running_average_bn is None: + use_running_average_bn = not update_batch_norm norm = functools.partial( nn.BatchNorm, - use_running_average=not update_batch_norm, + use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, dtype=self.dtype) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index d8de214f5..2747fc2db 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -148,7 +148,9 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} @@ -157,14 +159,16 @@ def model_fn( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=['batch_stats']) + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn) return logits, new_model_state else: logits = self._model.apply( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=False) + mutable=False, + use_running_average_bn=use_running_average_bn) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index ed05f4335..cb6287c5e 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -454,7 +454,11 @@ def setup(self): self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, + inputs, + input_paddings, + update_batch_norm, + use_running_average_bn): rank = inputs.ndim reduce_over_dims = list(range(0, rank - 1)) @@ -462,7 +466,12 @@ def __call__(self, inputs, input_paddings, train): momentum = self.config.batch_norm_momentum epsilon = self.config.batch_norm_epsilon - if train: + if use_running_average_bn: + mean = self.ra_mean.value + var = self.ra_var.value + + else: + # compute batch statistics mask = 1.0 - padding sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True) count_v = jnp.sum( @@ -478,16 +487,13 @@ def __call__(self, inputs, input_paddings, train): var = sum_vv / count_v - self.ra_mean.value = momentum * \ - self.ra_mean.value + (1 - momentum) * mean - self.ra_var.value = momentum * \ - self.ra_var.value + (1 - momentum) * var - else: - mean = self.ra_mean.value - var = self.ra_var.value + if update_batch_norm: + self.ra_mean.value = momentum * \ + self.ra_mean.value + (1 - 
momentum) * mean + self.ra_var.value = momentum * \ + self.ra_var.value + (1 - momentum) * var inv = (1 + self.gamma) / jnp.sqrt(var + epsilon) - bn_output = (inputs - mean) * inv + self.beta bn_output *= 1.0 - padding @@ -517,7 +523,12 @@ class ConvolutionBlock(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm, + use_running_average_bn): config = self.config inputs = LayerNorm(dim=config.encoder_dim)(inputs) @@ -546,7 +557,10 @@ def __call__(self, inputs, input_paddings, train): kernel_init=nn.initializers.xavier_uniform())( inputs) - inputs = BatchNorm(config)(inputs, input_paddings, train) + inputs = BatchNorm(config)(inputs, + input_paddings, + update_batch_norm, + use_running_average_bn) if config.activation_function_name == 'swish': activation_fn = nn.swish elif config.activation_function_name == 'gelu': @@ -586,7 +600,12 @@ class ConformerBlock(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm, + use_running_average): config = self.config padding_mask = jnp.expand_dims(1 - input_paddings, -1) @@ -597,7 +616,12 @@ def __call__(self, inputs, input_paddings, train): inputs, input_paddings, train) inputs = inputs + \ - ConvolutionBlock(config)(inputs, input_paddings, train) + ConvolutionBlock(config)(inputs, + input_paddings, + train, + update_batch_norm, + use_running_average + ) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( inputs, padding_mask, train) @@ -629,12 +653,23 @@ def setup(self): .use_dynamic_time_mask_max_frames) @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm: Optional[bool] = None, + use_running_average_bn: Optional[bool] = None): config = self.config outputs = inputs output_paddings = input_paddings + # Set BN args if not supplied for backwards compatibility + if update_batch_norm is None: + update_batch_norm = train + if use_running_average_bn is None: + use_running_average_bn = not train + # Compute normalized log mel spectrograms from input audio signal. preprocessing_config = preprocessor.LibrispeechPreprocessingConfig() outputs, output_paddings = preprocessor.MelFilterbankFrontend( @@ -660,7 +695,11 @@ def __call__(self, inputs, input_paddings, train): # Run the conformer encoder layers. for _ in range(config.num_encoder_layers): - outputs = ConformerBlock(config)(outputs, output_paddings, train) + outputs = ConformerBlock(config)(outputs, + output_paddings, + train, + update_batch_norm, + use_running_average_bn) outputs = LayerNorm(config.encoder_dim)(outputs) # Run the decoder which in this case is a trivial projection layer. 
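
For readers of this patch: the Conformer changes above split the old `train` flag into two independent BatchNorm controls — `use_running_average_bn` selects which statistics normalize the batch, while `update_batch_norm` decides whether the running averages are refreshed. Below is a minimal, self-contained Flax sketch of that pattern (illustrative only, not code from this repository; `ToyBatchNorm` and its hyperparameters are assumptions), showing how the two flags combine under `mutable=['batch_stats']`.

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class ToyBatchNorm(nn.Module):
  """Batch norm over the leading axis with explicit control flags."""
  dim: int
  momentum: float = 0.999
  epsilon: float = 1e-3

  @nn.compact
  def __call__(self, x, update_batch_norm, use_running_average_bn):
    # Running statistics live in the 'batch_stats' collection, like nn.BatchNorm.
    ra_mean = self.variable('batch_stats', 'mean',
                            lambda: jnp.zeros((self.dim,)))
    ra_var = self.variable('batch_stats', 'var',
                           lambda: jnp.ones((self.dim,)))
    if use_running_average_bn:
      # Normalize with the stored running averages (typical eval behavior).
      mean, var = ra_mean.value, ra_var.value
    else:
      # Normalize with statistics of the current batch.
      mean = jnp.mean(x, axis=0)
      var = jnp.var(x, axis=0)
      if update_batch_norm:
        # Refresh the running averages only when explicitly requested.
        ra_mean.value = (
            self.momentum * ra_mean.value + (1. - self.momentum) * mean)
        ra_var.value = (
            self.momentum * ra_var.value + (1. - self.momentum) * var)
    return (x - mean) / jnp.sqrt(var + self.epsilon)


if __name__ == '__main__':
  model = ToyBatchNorm(dim=4)
  x = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
  variables = model.init(jax.random.PRNGKey(1), x,
                         update_batch_norm=False,
                         use_running_average_bn=True)
  # Train-style step: batch statistics for normalization, running stats updated.
  y, new_state = model.apply(variables, x,
                             update_batch_norm=True,
                             use_running_average_bn=False,
                             mutable=['batch_stats'])
  # Extra forward pass that must not perturb the stats (e.g. a second SAM pass):
  y2, _ = model.apply(variables, x,
                      update_batch_norm=False,
                      use_running_average_bn=False,
                      mutable=['batch_stats'])
```

The combination the single `train` flag could not express is `use_running_average_bn=False` with `update_batch_norm=False`: normalize with fresh batch statistics without touching the running averages. The EMA direction in the sketch follows the JAX code above (a momentum near 1 weights the running average); the PyTorch models in this same patch switch to the complementary PyTorch convention, which is why their default becomes `1 - 0.999` and the update terms are swapped.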
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index f4d1ab0f3..e362f973b 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -107,7 +107,9 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN @@ -118,7 +120,8 @@ def model_fn( input_paddings, train=True, rngs={'dropout' : rng}, - mutable=['batch_stats']) + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn) return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( @@ -126,7 +129,8 @@ def model_fn( inputs, input_paddings, train=False, - mutable=False) + mutable=False, + use_running_average_bn=use_running_average_bn) return (logits, logit_paddings), model_state def _build_input_queue( diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py index 502cb093e..61400806a 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -40,7 +40,7 @@ class ConformerConfig: time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True input_dropout_rate: float = 0.1 - batch_norm_momentum: float = 0.999 + batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 use_specaug: bool = True attention_temperature: float = 1.0 @@ -369,10 +369,11 @@ def forward(self, inputs, input_paddings): mean = (masked_inp).sum(dim=(0, 1)) / count var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count - self.running_mean = self.momentum * self.running_mean + ( - 1 - self.momentum) * mean.detach() - self.running_var = self.momentum * self.running_var + ( - 1 - self.momentum) * var.detach() + self.running_mean = (1 - self.momentum) * self.running_mean + ( + self.momentum) * mean.detach() + self.running_var = (1 - self.momentum) * self.running_var + ( + self.momentum) * var.detach() + else: mean = self.running_mean var = self.running_var diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 8473fac0f..a0db6d607 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -1,5 +1,5 @@ import functools -from typing import Optional +from typing import Dict, Optional, Tuple from flax import jax_utils import jax @@ -56,6 +56,37 @@ def init_model_fn( params = jax_utils.replicate(params) return params, model_state + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: 
spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + variables = {'params': params, **model_state} + inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] + is_train_mode = mode == spec.ForwardPassMode.TRAIN + if update_batch_norm or is_train_mode: + (logits, logit_paddings), new_model_state = self._model.apply( + variables, + inputs, + input_paddings, + train=True, + rngs={'dropout' : rng}, + mutable=['batch_stats']) + return (logits, logit_paddings), new_model_state + else: + logits, logit_paddings = self._model.apply( + variables, + inputs, + input_paddings, + train=False, + mutable=False) + return (logits, logit_paddings), model_state + def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index a5ee3fa0a..bdf556f1c 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -36,7 +36,7 @@ class DeepspeechConfig: time_mask_max_ratio: float = 0.05 time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True - batch_norm_momentum: float = 0.999 + batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 # If None, defaults to 0.1. input_dropout_rate: Optional[float] = 0.1 @@ -264,10 +264,10 @@ def forward(self, inputs, input_paddings): sum_ = dist_nn.all_reduce(sum_) var = sum_ / count - self.running_mean = self.momentum * self.running_mean + ( - 1 - self.momentum) * mean.detach() - self.running_var = self.momentum * self.running_var + ( - 1 - self.momentum) * var.detach() + self.running_mean = (1 - self.momentum) * self.running_mean + ( + self.momentum) * mean.detach() + self.running_var = (1 - self.momentum) * self.running_var + ( + self.momentum) * var.detach() else: mean = self.running_mean var = self.running_var diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 98193f01f..a235c50cd 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -252,20 +252,23 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + 
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index 66fdc4ebb..06413f681 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -252,20 +252,23 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index ebc49d428..0e654d43c 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -224,20 +224,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return 
(updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index 524bc20af..dd0b8b076 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -224,20 +224,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 4f53afb56..a9f048f03 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -264,20 +264,23 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state 
del eval_results del hyperparameters diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 60a1f784d..4d3d2b341 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -264,20 +264,23 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del hyperparameters diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index f8e87ec2a..5a5319957 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -236,20 +236,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del hyperparameters diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py 
b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index 1de26417f..699b11268 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -236,20 +236,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del hyperparameters diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index 2971efe9a..97d6df9f1 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -1,7 +1,7 @@ """Training algorithm track submission functions for CIFAR10.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -110,21 +110,24 @@ def _loss_fn(params): # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. 
-def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type del global_step + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state per_device_rngs = jax.random.split(rng, jax.local_device_count()) diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index 358c6bffc..853064957 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -1,6 +1,6 @@ """Training algorithm track submission functions for CIFAR10.""" -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple import torch from torch.optim.lr_scheduler import CosineAnnealingLR @@ -53,21 +53,24 @@ def init_optimizer_state(workload: spec.Workload, return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del current_params_types del hyperparameters del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index 896609d51..6d05954a1 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -1,7 +1,7 @@ """Training algorithm track submission functions for MNIST.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing 
import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -75,20 +75,23 @@ def loss_fn(params): return new_optimizer_state, updated_params, new_model_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del global_step diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py index f1601e606..d27d7f742 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py @@ -1,6 +1,6 @@ """Training algorithm track submission functions for MNIST.""" -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple import torch @@ -32,21 +32,24 @@ def init_optimizer_state(workload: spec.Workload, return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del hyperparameters del loss_type del current_params_types + del train_state del eval_results del global_step diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index 2dd85c29b..efe238f26 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for an Adafactor optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, 
List, Optional, Tuple from flax import jax_utils import jax @@ -110,20 +110,23 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index e6fef17dc..377468612 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for Adafactor in PyTorch.""" from functools import partial -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -190,20 +190,23 @@ def step(self, closure=None): return loss -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 80a963600..31e0a6801 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, 
Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -110,20 +110,23 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 32353e5b4..27ceaeef7 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in PyTorch.""" -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -51,20 +51,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index 27d635ee9..be13ab540 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a LAMB optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, 
Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -118,20 +118,23 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index 7d0d8763e..d3b491e75 100644 --- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for a LAMB optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -189,20 +189,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index cccb3c1b5..3eef23942 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SGD with HeavyBall momentum optimizer in Jax.""" 
 import functools
-from typing import Callable, Dict, Iterator, List, Tuple
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

 from flax import jax_utils
 import jax
@@ -144,20 +144,23 @@ def _loss_fn(params):
   return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results

   optimizer_state, opt_update_fn = optimizer_state
diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py
index ec5c0b31c..cf474ebdd 100644
--- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py
@@ -1,6 +1,6 @@
 """Submission file for a SGD with HeavyBall momentum optimizer in PyTorch."""

-from typing import Callable, Dict, Iterator, List, Tuple
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

 from absl import logging
 import optax
@@ -67,20 +67,23 @@ def create_lr_schedule_fn(
   return lr_schedule_fn


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results

   current_model = current_param_container
diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py
index 98193f01f..a235c50cd 100644
--- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py
+++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py
@@ -252,20 +252,23 @@ def _loss_fn(params):
   return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results

   optimizer_state, opt_update_fn = optimizer_state
diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py
index ebc49d428..0e654d43c 100644
--- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py
@@ -1,7 +1,7 @@
 """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

 import math
-from typing import Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from absl import logging
 import torch
@@ -224,20 +224,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
   return optimizer_state


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results

   current_model = current_param_container
diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py
index f3b0aeed4..553b3e478 100644
--- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py
+++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py
@@ -1,7 +1,7 @@
 """Submission file for a SGD with Nesterov momentum optimizer in Jax."""

 import functools
-from typing import Callable, Dict, Iterator, List, Tuple
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

 from flax import jax_utils
 import jax
@@ -144,20 +144,23 @@ def _loss_fn(params):
   return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results

   optimizer_state, opt_update_fn = optimizer_state
diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py
index fe9154934..ba8c69e6c 100644
--- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py
@@ -1,6 +1,6 @@
 """Submission file for a SGD with Nesterov momentum optimizer in PyTorch."""

-from typing import Callable, Dict, Iterator, List, Tuple
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

 from absl import logging
 import optax
@@ -67,20 +67,23 @@ def create_lr_schedule_fn(
   return lr_schedule_fn


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results

   current_model = current_param_container
diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py
index 85b3d7441..b5c7069cb 100644
--- a/reference_algorithms/paper_baselines/sam/jax/submission.py
+++ b/reference_algorithms/paper_baselines/sam/jax/submission.py
@@ -1,7 +1,7 @@
 """Submission file for a SAM optimizer with warmup+cosine LR in Jax."""

 import functools
-from typing import Dict, Iterator, List, Optional, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from flax import jax_utils
 import jax
@@ -197,20 +197,23 @@ def _loss_fn(params, update_batch_norm=True):
   return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results

   optimizer_state, opt_update_fn = optimizer_state
diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py
index 2cab75972..b69945d51 100644
--- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py
@@ -1,6 +1,6 @@
 """Submission file for a SAM optimizer with warmup+cosine LR in PyTorch."""

-from typing import Callable, Dict, Iterator, List, Tuple
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

 from absl import logging
 import torch
@@ -131,20 +131,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
   return optimizer_state


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results

   current_model = current_param_container
diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py
index 9c6b66b7f..8f0b311a0 100644
--- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py
+++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py
@@ -1,7 +1,7 @@
 """Submission file for a Shampoo optimizer with warmup+cosine LR in Jax."""

 import functools
-from typing import Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from flax import jax_utils
 import jax
@@ -113,20 +113,23 @@ def _loss_fn(params):
   return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results

   optimizer_state, opt_update_fn = optimizer_state
diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py
index 2a641b520..51b20181b 100644
--- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py
+++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py
@@ -1,6 +1,6 @@
 """Update submission function in Jax."""
 import functools
-from typing import Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import jax
 from jax import lax
@@ -69,20 +69,23 @@ def _loss_fn(params):
   return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results

   optimizer_state, opt_update_fn = optimizer_state
diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py
index f9e40212b..6203c58b3 100644
--- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py
+++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py
@@ -1,6 +1,6 @@
 """Batch size and update submission functions in PyTorch."""

-from typing import Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 from absl import logging
 import torch
@@ -12,20 +12,23 @@
 USE_PYTORCH_DDP = pytorch_setup()[0]


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
+  del train_state
   del eval_results

   current_model = current_param_container
diff --git a/submission_runner.py b/submission_runner.py
index 551173bf5..1a66acc58 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -17,11 +17,13 @@
 import datetime
 import gc
 import importlib
+from inspect import signature
 import itertools
 import json
 import os
 import struct
 import time
+from types import MappingProxyType
 from typing import Any, Dict, Optional, Tuple

 from absl import app
@@ -273,6 +275,10 @@ def train_once(
                                              hyperparameters,
                                              opt_init_rng)
   logging.info('Initializing metrics bundle.')
+
+  # Check if 'train_state' is in the function signature
+  needs_train_state = 'train_state' in signature(update_params).parameters
+
   # Bookkeeping.
   train_state = {
       'validation_goal_reached': False,
@@ -359,7 +365,9 @@ def train_once(
           optimizer_state=optimizer_state,
           eval_results=eval_results,
           global_step=global_step,
-          rng=update_rng)
+          rng=update_rng,
+          **({'train_state': MappingProxyType(train_state)}
+             if needs_train_state else {}))
     except spec.TrainingCompleteError:
       train_state['training_complete'] = True
     global_step += 1
diff --git a/submissions/template/submission.py b/submissions/template/submission.py
index 5ef195db5..ab98c9958 100644
--- a/submissions/template/submission.py
+++ b/submissions/template/submission.py
@@ -4,7 +4,7 @@
 and https://github.com/mlcommons/algorithmic-efficiency/blob/main/DOCUMENTATION.md#disallowed-submissions
 for guidelines.
 """
-from typing import Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

 from algorithmic_efficiency import spec
@@ -22,17 +22,19 @@ def init_optimizer_state(workload: spec.Workload,
   pass


-def update_params(workload: spec.Workload,
-                  current_param_container: spec.ParameterContainer,
-                  current_params_types: spec.ParameterTypeTree,
-                  model_state: spec.ModelAuxiliaryState,
-                  hyperparameters: spec.Hyperparameters,
-                  batch: Dict[str, spec.Tensor],
-                  loss_type: spec.LossType,
-                  optimizer_state: spec.OptimizerState,
-                  eval_results: List[Tuple[int, float]],
-                  global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+def update_params(
+    workload: spec.Workload,
+    current_param_container: spec.ParameterContainer,
+    current_params_types: spec.ParameterTypeTree,
+    model_state: spec.ModelAuxiliaryState,
+    hyperparameters: spec.Hyperparameters,
+    batch: Dict[str, spec.Tensor],
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState,
+    train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
   """
   Returns:
     (new_optimizer_state, update_fn)
diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py
index 74c06e180..f107be8d7 100644
--- a/tests/reference_algorithm_tests.py
+++ b/tests/reference_algorithm_tests.py
@@ -408,6 +408,7 @@ def _test_submission(workload_name,
       workload_path=workload_metadata['workload_path'],
       workload_class_name=workload_metadata['workload_class_name'],
       return_class=True)
+  print(f'Workload class for {workload_name} is {workload_class}')

   submission_module_path = workloads.convert_filepath_to_module(submission_path)
   submission_module = importlib.import_module(submission_module_path)
@@ -471,6 +472,7 @@ def _test_submission(workload_name,
           batch=batch,
           loss_type=workload.loss_type,
           optimizer_state=optimizer_state,
+          train_state={},
           eval_results=[],
           global_step=global_step,
           rng=update_rng)