Skip to content

Commit

Permalink
add prepare_for_eval to spec.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Niccolo-Ajroldi committed Oct 18, 2024
1 parent 420b583 commit d9c4ee9
Showing 1 changed file with 30 additions and 13 deletions.
43 changes: 30 additions & 13 deletions algorithmic_efficiency/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,19 +406,6 @@ def init_optimizer_state(workload: Workload,
RandomState
],
UpdateReturn]
PrepareForEvalFn = Callable[[
Workload,
ParameterContainer,
ParameterTypeTree,
ModelAuxiliaryState,
Hyperparameters,
LossType,
OptimizerState,
List[Tuple[int, float]],
int,
RandomState
],
UpdateReturn]


# Each call to this function is considered a "step".
Expand All @@ -442,6 +429,36 @@ def update_params(workload: Workload,
pass


PrepareForEvalFn = Callable[[
Workload,
ParameterContainer,
ParameterTypeTree,
ModelAuxiliaryState,
Hyperparameters,
LossType,
OptimizerState,
List[Tuple[int, float]],
int,
RandomState
],
UpdateReturn]


# Prepare model and optimizer for evaluation.
def prepare_for_eval(workload: Workload,
current_param_container: ParameterContainer,
current_params_types: ParameterTypeTree,
model_state: ModelAuxiliaryState,
hyperparameters: Hyperparameters,
loss_type: LossType,
optimizer_state: OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: RandomState) -> UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
pass


DataSelectionFn = Callable[[
Workload,
Iterator[Dict[str, Any]],
Expand Down

0 comments on commit d9c4ee9

Please sign in to comment.