From 03785451e763ca5fb1e834be40738db415c20b40 Mon Sep 17 00:00:00 2001
From: LindaCY <30959115+LindaCY@users.noreply.github.com>
Date: Sun, 20 Jan 2019 00:24:37 +0800
Subject: [PATCH] Update trainer.py

Add a training trick: halve the learning rate if the dev metrics have not
improved for [halve_lr_epochs] epochs, then restart training from the
previous best model.
---
 fastNLP/core/trainer.py | 38 +++++++++++++++++++++++++++++++++++---
 1 file changed, 35 insertions(+), 3 deletions(-)

diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 109315a3..13c5ca52 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -33,7 +33,7 @@ class Trainer(object):
     def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
-                 validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
+                 validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0), halve_lr_epochs=-1,
                  check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False,
                  callbacks=None):
         """
@@ -49,6 +49,7 @@ def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch
         :param bool use_cuda: whether to use CUDA in training.
         :param str save_path: file path to save models
         :param Optimizer optimizer: an optimizer object
+        :param int halve_lr_epochs: halve the learning rate if the metrics have not improved for [halve_lr_epochs] epochs. Default: -1 (never halve)
         :param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict.\\
             `ignore` will not check unused field; `warning` when warn if some field are not used; `strict` means
             it will raise error if some field are not used.
@@ -106,6 +107,7 @@ def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch
         self.dev_data = dev_data  # If None, No validation.
         self.model = model
         self.losser = losser
+        self.halve_lr_epochs = halve_lr_epochs
         self.metrics = metrics
         self.n_epochs = int(n_epochs)
         self.batch_size = int(batch_size)
@@ -124,6 +126,9 @@ def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch
         self.use_tqdm = use_tqdm
         self.print_every = abs(self.print_every)
+
+        for group in self.optimizer.param_groups:
+            self.lr = group['lr']

         if self.dev_data is not None:
             self.tester = Tester(model=self.model,
@@ -228,6 +233,7 @@ def _train(self):
         else:
             inner_tqdm = tqdm
         self.step = 0
+        self.bad_valid = 0
         start = time.time()
         data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False)
         total_steps = data_iterator.num_batches * self.n_epochs
@@ -258,8 +264,10 @@
                     self._update()
                     # lr scheduler; lr_finder; one_cycle
                     self.callback_manager.after_step(self.optimizer)
-
+
+                    # Write the training loss summary
                     self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
+
                     for name, param in self.model.named_parameters():
                         if param.requires_grad:
                             self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
@@ -316,6 +324,24 @@ def _do_validation(self, epoch, step):
             self.best_dev_perf = res
             self.best_dev_epoch = epoch
             self.best_dev_step = step
+
+        # halve the learning rate if not improving for [halve_lr_epochs] epochs, and restart training from the best model.
+        else:
+            self.bad_valid += 1
+            if self.halve_lr_epochs != -1:
+                if self.bad_valid >= self.halve_lr_epochs:
+                    self.lr = self.lr / 2.0
+                    print("Halved the learning rate to {}".format(self.lr))
+                    model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
+                    load_succeed = self._load_model(self.model, model_name)
+                    if load_succeed:
+                        print("Reloaded the best model.")
+                    else:
+                        print("Failed to reload the best model.")
+                    self._set_lr(self.optimizer, self.lr)
+                    self.bad_valid = 0
+            print("bad valid: {}".format(self.bad_valid))
+
         # get validation results; adjust optimizer
         self.callback_manager.after_valid(res, self.metric_key, self.optimizer)
         return res
@@ -409,7 +435,13 @@ def _better_eval_result(self, metrics):
         else:
             is_better = False
         return is_better
-
+
+    def _set_lr(self, optimizer, lr):
+        # if self.optimizer == "YFOptimizer":
+        #     optimizer.set_lr_factor(lr)
+        # else:
+        for group in optimizer.param_groups:
+            group['lr'] = lr
 
 DEFAULT_CHECK_BATCH_SIZE = 2
 DEFAULT_CHECK_NUM_BATCH = 2
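
Note for reviewers: below is a minimal standalone sketch of the same trick in plain PyTorch, for reference only; after halve_lr_epochs validations without improvement it reloads the best weights and halves the learning rate. The toy model, the validate() stand-in, and the reset of bad_valid on improvement are illustrative assumptions and not part of this patch or of fastNLP's API.

import copy
import torch
import torch.nn as nn

model = nn.Linear(10, 2)                                   # stand-in for a real model
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

halve_lr_epochs = 3                                        # same meaning as the new Trainer argument
bad_valid = 0                                              # validations without improvement so far
best_metric = float("-inf")
best_state = copy.deepcopy(model.state_dict())

def validate():
    # stand-in for Tester.test(); returns a scalar dev metric
    return torch.rand(1).item()

for epoch in range(20):
    # ... training steps for this epoch would go here ...
    metric = validate()
    if metric > best_metric:
        # improvement: remember the best weights and clear the counter
        best_metric = metric
        best_state = copy.deepcopy(model.state_dict())
        bad_valid = 0
    else:
        bad_valid += 1
        if halve_lr_epochs != -1 and bad_valid >= halve_lr_epochs:
            model.load_state_dict(best_state)              # restart from the best model
            for group in optimizer.param_groups:           # halve the learning rate
                group["lr"] = group["lr"] / 2.0
            bad_valid = 0

With the patch applied, the same behaviour is requested from the Trainer by passing dev_data, metrics and, for example, halve_lr_epochs=3 to its constructor.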