diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index 5448ff1f2..95bea68aa 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SAM optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Optional, Tuple, Any +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index e66b1ab23..a98d134fc 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -1,6 +1,6 @@ """Update submission function in Jax.""" import functools -from typing import Dict, List, Tuple, Any +from typing import Any, Dict, List, Tuple import jax from jax import lax diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index c031f3ac4..586429e37 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -1,6 +1,6 @@ """Batch size and update submission functions in PyTorch.""" -from typing import Dict, List, Tuple, Any +from typing import Any, Dict, List, Tuple from absl import logging import torch