Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(Feature) Add the ability to use other optimizers and LRScheduler #78

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions pie/default_settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,22 @@
"word_dropout": 0.0, // input word dropout
"optimizer": "Adam", // optimizer type
"clip_norm": 5.0, // clip norm of gradient up to this value

"lr_scheduler": "ReduceLROnPlateau", // LR Scheduler to use: ReduceLROnPlateau,
"lr_delayed":0, // Use only with other schedulers than ReduceLRONPlateau for efficiency
"lr": 0.001,
"min_lr": 0.000001, // minimum learning rate
"checks_per_epoch": 1, // check model on dev-set so many times during epoch

// ReduceLROnPlateau parameters
"lr_factor": 0.75, // lr schedule (decrease lr by this factor after `lr_patience` epochs
// without improvement on dev-set data)
"min_lr": 0.000001, // minimum learning rate
"lr_patience": 2, // patience for lr schedule
"checks_per_epoch": 1, // check model on dev-set so many times during epoch

// CosineAnnealingLR parameters
"lr_T_max": 40,
// CosineAnnealingWarmRestarts parameters
"lr_T_0": 10, // Number of iteration before first restart

// * Model hyperparameters
"wemb_dim": 0, // word-level embedding dimension (if 0 no word embeddings are use)
Expand Down
143 changes: 123 additions & 20 deletions pie/trainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import os
import uuid
import logging
Expand All @@ -12,7 +11,10 @@

import torch
from torch import optim
from torch.optim.optimizer import Optimizer
from torch.nn.utils import clip_grad_norm_
import torch_optimizer as ext_optims
from typing import ClassVar

logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.INFO)

Expand Down Expand Up @@ -149,20 +151,25 @@ def get_weights(self):
return {task: self.tasks[task]['weight'] for task in self.tasks}


class DelayerScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Wrap another LR scheduler and hold it back for a fixed number of steps.

    While fewer than `delay` calls to step() have been made, the wrapped
    scheduler is left untouched (the learning rate keeps its initial value);
    once past the delay, every step() is forwarded to `base_scheduler`.
    """

    def __init__(self, optimizer, delay: int, base_scheduler: torch.optim.lr_scheduler._LRScheduler):
        """
        :param optimizer: Optimizer whose learning rate is being scheduled
        :param delay: Number of step() calls to swallow before delegating
        :param base_scheduler: Scheduler that takes over once the delay is spent
        """
        # Start at -1: _LRScheduler.__init__ triggers one implicit step(),
        # which brings the counter to 0 before training actually starts.
        self.nb_steps = -1
        self.delay = delay
        self.base_scheduler = base_scheduler
        super(DelayerScheduler, self).__init__(optimizer)

    def step(self, *args, **kwargs):
        """Count one step; forward to the base scheduler once past the delay."""
        self.nb_steps += 1
        if self.steps > self.delay:
            self.base_scheduler.step(*args, **kwargs)

    @property
    def waiting(self):
        # True while the wrapped scheduler has not been activated yet
        return self.steps <= self.delay

    @property
    def steps(self):
        # Number of step() calls seen so far (including the implicit one
        # issued by _LRScheduler.__init__)
        return self.nb_steps


class Trainer(object):
Expand All @@ -178,12 +185,99 @@ class Trainer(object):
report_freq
checks_per_epoch
"""

@staticmethod
def get_optimizer(optimizer_name: str) -> ClassVar[Optimizer]:
""" Allows for getting new optimizers from the torch-optimizer library without
breaking previous behaviour

:param optimizer_name: Optimizer Name, eg. Adam, SGD, Ranger
:return: Optimizer class
"""
if hasattr(optim, optimizer_name):
return getattr(optim, optimizer_name)
elif hasattr(ext_optims, optimizer_name):
return getattr(ext_optims, optimizer_name)

def print_lr_scheduler(self, lr_scheduler: optim.lr_scheduler._LRScheduler):
""" Display information using print about a LRScheduler

:param lr_scheduler:
:return:
"""
# If we use a Delayer, we print information about the delayer until it finishes waiting
if isinstance(lr_scheduler, DelayerScheduler):
if lr_scheduler.waiting:
print('<LRScheduler type="{}" lr="{:g}" delay="{}" steps="{}"/>'.format(
type(lr_scheduler).__name__,
self.optimizer.param_groups[0]['lr'],
lr_scheduler.delay,
lr_scheduler.steps
))
else:
self.print_lr_scheduler(lr_scheduler.base_scheduler)
# Continue to display former information for ReduceLROnPlateau
elif isinstance(lr_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
print('<LrScheduler type="{}" lr="{:g}" steps="{}" patience="{}" threshold="{}"/>'.format(
type(lr_scheduler).__name__,
self.optimizer.param_groups[0]['lr'],
lr_scheduler.num_bad_epochs,
lr_scheduler.patience,
lr_scheduler.threshold
))
# There are no specific information to display for some schedulers if not all
else:
print('<LrScheduler type="{}" lr="{:g}"/>'.format(
type(lr_scheduler).__name__,
self.optimizer.param_groups[0]['lr']
))

def get_scheduler(self, settings) -> optim.lr_scheduler._LRScheduler:
""" Initialize a LRScheduler based on settings

:param settings: Settings fed through JSON
:return: The LRScheduler required by the settings, disregarding delay
"""
if not self.optimizer:
raise Exception("Scheduler needs to be set after optimizer")
if settings.lr_scheduler == "ReduceLROnPlateau":
return optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer, mode='max', factor=settings.lr_factor,
patience=settings.lr_patience, min_lr=settings.min_lr
)
elif settings.lr_scheduler == "CosineAnnealingLR":
return optim.lr_scheduler.CosineAnnealingLR(
self.optimizer, T_max=settings.lr_T_max, eta_min=settings.min_lr
)
elif settings.lr_scheduler == "CosineAnnealingWarmRestarts":
return optim.lr_scheduler.CosineAnnealingWarmRestarts(
self.optimizer, T_0=settings.lr_T_0, eta_min=settings.min_lr
)
else:
raise ValueError(f"Unknown scheduler {settings.lr_scheduler}")

def step_lr_scheduler(self, loss):
""" Apply a step to the LRScheduler.

Some scheduler use loss as the information for steps, some use the epoch_id.

:param loss: Loss computed from the TaskScheduler
"""
use_loss = isinstance(self.lr_scheduler, optim.lr_scheduler.ReduceLROnPlateau)
if isinstance(self.lr_scheduler, DelayerScheduler):
use_loss = isinstance(self.lr_scheduler.base_scheduler, optim.lr_scheduler.ReduceLROnPlateau)

if use_loss:
self.lr_scheduler.step(metrics=loss)
else:
self.lr_scheduler.step(epoch=None)

def __init__(self, settings, model, dataset, num_instances):
self.target_task = get_target_task(settings)
self.verbose = settings.verbose
self.dataset = dataset
self.model = model
self.optimizer = getattr(optim, settings.optimizer)(
self.optimizer = self.get_optimizer(settings.optimizer)(
model.parameters(), lr=settings.lr)
self.clip_norm = settings.clip_norm

Expand All @@ -199,9 +293,18 @@ def __init__(self, settings, model, dataset, num_instances):
self.check_freq = 0 # no checks

self.task_scheduler = TaskScheduler(settings)
self.lr_scheduler = LRScheduler(
self.optimizer, factor=settings.lr_factor,
patience=settings.lr_patience, min_lr=settings.min_lr)

lr_scheduler: optim.lr_scheduler._LRScheduler = self.get_scheduler(
settings
)
if settings.lr_delayed > 0:
self.lr_scheduler = DelayerScheduler(
optimizer=self.optimizer,
delay=settings.lr_delayed,
base_scheduler=lr_scheduler
)
else:
self.lr_scheduler = lr_scheduler

if settings.verbose:
print()
Expand All @@ -214,7 +317,7 @@ def __init__(self, settings, model, dataset, num_instances):
print()
print("::: LR schedule :::")
print()
print(self.lr_scheduler)
self.print_lr_scheduler(self.lr_scheduler)
print()

def weight_loss(self, loss):
Expand Down Expand Up @@ -278,12 +381,12 @@ def run_check(self, devset):
dev_scores['lm_bwd'] = dev_loss['lm_bwd']

self.task_scheduler.step(dev_scores, self.model)
self.lr_scheduler.step(dev_scores[self.target_task])
self.step_lr_scheduler(loss=dev_scores[self.target_task])

if self.verbose:
print(self.task_scheduler)
print()
print(self.lr_scheduler)
self.print_lr_scheduler(self.lr_scheduler)
print()

return dev_scores
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ torch>=1.3.1,<1.4.0
pyyaml==5.1b3
typing<4.0
click>=7.0,<8.0
torch-optimizer