tests and formatting
Justin Sanders committed Sep 28, 2023
1 parent f136798 commit 5557d97
Showing 2 changed files with 55 additions and 19 deletions.
casanovo/denovo/model.py (13 additions, 19 deletions)
@@ -110,7 +110,7 @@ def __init__(
         tb_summarywriter: Optional[
             torch.utils.tensorboard.SummaryWriter
         ] = None,
-        lr_schedule = None,
+        lr_schedule=None,
         warmup_iters: int = 100_000,
         max_iters: int = 600_000,
         out_writer: Optional[ms_io.MztabWriter] = None,
@@ -968,27 +968,23 @@ def configure_optimizers(
         # Add linear learning rate scheduler for warmup
         lr_schedulers = [
             torch.optim.lr_scheduler.LinearLR(
-                optimizer,
-                start_factor=1e-10,
-                total_iters=self.warmup_iters)
-        ]
-        if self.lr_schedule == 'cosine':
+                optimizer, start_factor=1e-10, total_iters=self.warmup_iters
+            )
+        ]
+        if self.lr_schedule == "cosine":
             lr_schedulers.append(
-                CosineScheduler(
-                    optimizer,
-                    max_iters=self.max_iters
-                )
+                CosineScheduler(optimizer, max_iters=self.max_iters)
             )
-        elif self.lr_schedule == 'linear':
+        elif self.lr_schedule == "linear":
             lr_schedulers.append(
                 torch.optim.lr_scheduler.LinearLR(
-                    optimizer,
-                    start_factor=1,
+                    optimizer,
+                    start_factor=1,
                     end_factor=0,
-                    total_iters=self.max_iters
-                )
+                    total_iters=self.max_iters,
+                )
             )
-        #Combine learning rate schedulers
+        # Combine learning rate schedulers
         lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler(lr_schedulers)
         # Apply learning rate scheduler per step.
         return [optimizer], {"scheduler": lr_scheduler, "interval": "step"}
@@ -1008,9 +1004,7 @@ class CosineScheduler(torch.optim.lr_scheduler._LRScheduler):
         The total number of iterations.
     """
 
-    def __init__(
-        self, optimizer: torch.optim.Optimizer, max_iters: int
-    ):
+    def __init__(self, optimizer: torch.optim.Optimizer, max_iters: int):
         self.max_iters = max_iters
         super().__init__(optimizer)
 
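The diff above touches only CosineScheduler's __init__; its get_lr body is not shown in this commit. The sketch below is an assumed half-cosine decay written against the standard _LRScheduler subclassing pattern, not the repository's actual implementation.

# Illustrative sketch only: an assumed cosine-decay get_lr for a custom
# _LRScheduler subclass with the same __init__ signature as above.
import math
import torch

class CosineDecayExample(torch.optim.lr_scheduler._LRScheduler):
    """Decay each base learning rate toward zero along a half cosine."""

    def __init__(self, optimizer: torch.optim.Optimizer, max_iters: int):
        self.max_iters = max_iters
        super().__init__(optimizer)

    def get_lr(self):
        # self.last_epoch counts scheduler steps; clamp progress at max_iters.
        progress = min(self.last_epoch, self.max_iters) / self.max_iters
        factor = 0.5 * (1 + math.cos(math.pi * progress))
        return [base_lr * factor for base_lr in self.base_lrs]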
tests/unit_tests/test_unit.py (42 additions, 0 deletions)
@@ -535,3 +535,45 @@ def test_run_map(mgf_small):
     out_writer.set_ms_run([os.path.abspath(mgf_small.name)])
     assert os.path.basename(mgf_small.name) not in out_writer._run_map
     assert os.path.abspath(mgf_small.name) in out_writer._run_map
+
+
+def test_lr_schedule():
+    """Test that the learning rate schedule is setup correctly"""
+
+    # Constant lr schedule is setup correctly
+    model = Spec2Pep(lr_schedule="constant")
+    assert model.lr_schedule is not None
+    assert model.lr_schedule == "constant"
+
+    # Learning rate schedule is applied every step rather than epoch
+    _, schedule = model.configure_optimizers()
+    assert schedule["interval"] == "step"
+
+    # Constant lr schedule includes only a warmup period
+    schedulers = schedule["scheduler"].state_dict()["_schedulers"]
+    assert len(schedulers) == 1
+
+    # Default warmup period lasts correct number of iters
+    assert schedulers[0]["start_factor"] == 1e-10
+    assert schedulers[0]["end_factor"] == 1
+    assert schedulers[0]["total_iters"] == 100000
+
+    # Linear lr schedule with custom warmup and max iters
+    model = Spec2Pep(lr_schedule="linear", warmup_iters=10, max_iters=100)
+    _, schedule = model.configure_optimizers()
+    schedulers = schedule["scheduler"].state_dict()["_schedulers"]
+
+    assert len(schedulers) == 2
+    assert schedulers[0]["total_iters"] == 10
+
+    assert schedulers[1]["start_factor"] == 1
+    assert schedulers[1]["end_factor"] == 0
+    assert schedulers[1]["total_iters"] == 100
+
+    # Cosine lr schedule
+    model = Spec2Pep(lr_schedule="cosine")
+    _, schedule = model.configure_optimizers()
+    schedulers = schedule["scheduler"].state_dict()["_schedulers"]
+
+    assert len(schedulers) == 2
+    assert schedulers[1]["_last_lr"][0] == [0.001]
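For reference, a standalone sketch (not from this commit) of the state_dict layout the new test relies on: ChainedScheduler nests each child scheduler's state under the "_schedulers" key. The dummy parameter, optimizer, and learning rate below are illustrative assumptions.

# Standalone sketch of the nested state_dict structure inspected by the test.
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=1e-3)
warmup = torch.optim.lr_scheduler.LinearLR(
    opt, start_factor=1e-10, total_iters=100_000
)
chained = torch.optim.lr_scheduler.ChainedScheduler([warmup])

state = chained.state_dict()
# Each entry in "_schedulers" is the child scheduler's own state_dict.
print(state["_schedulers"][0]["start_factor"])  # 1e-10
print(state["_schedulers"][0]["total_iters"])   # 100000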