From 86114ef970c832ea5b8ed15c47856e6d6d325df3 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 25 Oct 2024 11:47:47 +0200 Subject: [PATCH] fix missing import Optional --- reference_algorithms/paper_baselines/momentum/jax/submission.py | 2 +- .../paper_baselines/momentum/pytorch/submission.py | 2 +- reference_algorithms/paper_baselines/nesterov/jax/submission.py | 2 +- .../paper_baselines/nesterov/pytorch/submission.py | 2 +- reference_algorithms/paper_baselines/sam/pytorch/submission.py | 2 +- .../target_setting_algorithms/jax_submission_base.py | 2 +- .../target_setting_algorithms/pytorch_submission_base.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index b173ba8ba..346abe652 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SGD with HeavyBall momentum optimizer in Jax.""" import functools -from typing import Any, Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index c063f0a64..090a8bc01 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SGD with HeavyBall momentum optimizer in PyTorch.""" -from typing import Any, Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from absl import logging import optax diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 35ef2bfa8..fa5329778 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SGD with Nesterov momentum optimizer in Jax.""" import functools -from typing import Any, Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index 0b7cc570b..ce0854f7d 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SGD with Nesterov momentum optimizer in PyTorch.""" -from typing import Any, Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from absl import logging import optax diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index a793673f9..e9c9c9bc4 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SAM optimizer with warmup+cosine LR in PyTorch.""" -from typing import Any, Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from absl import logging import torch diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 999422fb0..6914da94e 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -1,6 +1,6 @@ """Update submission function in Jax.""" import functools -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple import jax from jax import lax diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index 92f222a18..606253e32 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -1,6 +1,6 @@ """Batch size and update submission functions in PyTorch.""" -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple from absl import logging import torch