"""Example script: run a single-epoch DARTS search on NasBench301.

The architectural weights are recorded during the search and plotted as
heatmaps (see the ``save_arch_weights`` / ``plot_arch_weights`` config flags
set below). Output is written below ``config.save``.
"""
import os
import logging
from naslib.defaults.trainer import Trainer
from naslib.optimizers import DARTSOptimizer, GDASOptimizer, RandomSearch
from naslib.search_spaces import NasBench301SearchSpace, SimpleCellSearchSpace

from naslib.utils import set_seed, setup_logger, get_config_from_args


def main():
    """Configure and launch the search.

    Kept inside a function (instead of at module level) so that importing
    this module does not start a search as a side effect.
    """
    config = get_config_from_args()  # use --help to see the options
    # Uncomment to override the batch size used during the search:
    # config.search.batch_size = 128
    config.search.epochs = 1
    config.save_arch_weights = True
    config.plot_arch_weights = True
    config.save_arch_weights_path = f"{config.save}/save_arch"
    set_seed(config.seed)

    logger = setup_logger(config.save + "/log.log")
    logger.setLevel(logging.INFO)  # default DEBUG is very verbose

    # Use SimpleCellSearchSpace() instead for a less heavy search.
    search_space = NasBench301SearchSpace()

    optimizer = DARTSOptimizer(config)
    optimizer.adapt_search_space(search_space)

    trainer = Trainer(optimizer, config)
    trainer.search()


if __name__ == "__main__":
    main()
config, lightweight_output=False): } ) - def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int], None]=None, report_incumbent=True): + def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int], None] = None, + report_incumbent=True): """ Start the architecture search. @@ -83,7 +88,7 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int resume_from (str): Checkpoint file to resume from. If not given then train from scratch. """ - logger.info("Beginning search") + logger.info("Start training") np.random.seed(self.config.search.seed) torch.manual_seed(self.config.search.seed) @@ -108,11 +113,29 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int for e in range(start_epoch, self.epochs): + # create the arch directory (without overwriting) + if hasattr(self.config, "save_arch_weights"): + if self.config.save_arch_weights: + Path(f"{self.config.save_arch_weights_path}/epoch_{e}").mkdir(parents=True, exist_ok=True) + start_time = time.time() self.optimizer.new_epoch(e) + arch_weights_lst = [] if self.optimizer.using_step_function: for step, data_train in enumerate(self.train_queue): + + # save arch weights to array of tensors + if hasattr(self.config, "save_arch_weights"): + if self.config.save_arch_weights: + if len(arch_weights_lst) == 0: + for alpha_i in self.optimizer.architectural_weights: + arch_weights_lst.append(torch.unsqueeze(alpha_i.detach(), dim=0)) + else: + for idx, alpha_i in enumerate(self.optimizer.architectural_weights): + arch_weights_lst[idx] = torch.cat( + (arch_weights_lst[idx], torch.unsqueeze(alpha_i.detach(), dim=0)), dim=0) + data_train = ( data_train[0].to(self.device), data_train[1].to(self.device, non_blocking=True), @@ -151,11 +174,11 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int end_time = time.time() - self.search_trajectory.train_acc.append(self.train_top1.avg) - 
self.search_trajectory.train_loss.append(self.train_loss.avg) - self.search_trajectory.valid_acc.append(self.val_top1.avg) - self.search_trajectory.valid_loss.append(self.val_loss.avg) - self.search_trajectory.runtime.append(end_time - start_time) + self.errors_dict.train_acc.append(self.train_top1.avg) + self.errors_dict.train_loss.append(self.train_loss.avg) + self.errors_dict.valid_acc.append(self.val_top1.avg) + self.errors_dict.valid_loss.append(self.val_loss.avg) + self.errors_dict.runtime.append(end_time - start_time) else: end_time = time.time() # TODO: nasbench101 does not have train_loss, valid_loss, test_loss implemented, so this is a quick fix for now @@ -168,28 +191,28 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int ) = self.optimizer.train_statistics(report_incumbent) train_loss, valid_loss, test_loss = -1, -1, -1 - self.search_trajectory.train_acc.append(train_acc) - self.search_trajectory.train_loss.append(train_loss) - self.search_trajectory.valid_acc.append(valid_acc) - self.search_trajectory.valid_loss.append(valid_loss) - self.search_trajectory.test_acc.append(test_acc) - self.search_trajectory.test_loss.append(test_loss) - self.search_trajectory.runtime.append(end_time - start_time) - self.search_trajectory.train_time.append(train_time) + self.errors_dict.train_acc.append(train_acc) + self.errors_dict.train_loss.append(train_loss) + self.errors_dict.valid_acc.append(valid_acc) + self.errors_dict.valid_loss.append(valid_loss) + self.errors_dict.test_acc.append(test_acc) + self.errors_dict.test_loss.append(test_loss) + self.errors_dict.runtime.append(end_time - start_time) + self.errors_dict.train_time.append(train_time) self.train_top1.avg = train_acc self.val_top1.avg = valid_acc self.periodic_checkpointer.step(e) anytime_results = self.optimizer.test_statistics() - # if anytime_results: + if anytime_results: # record anytime performance - # self.search_trajectory.arch_eval.append(anytime_results) - # 
log_every_n_seconds( - # logging.INFO, - # "Epoch {}, Anytime results: {}".format(e, anytime_results), - # n=5, - # ) + self.errors_dict.arch_eval.append(anytime_results) + log_every_n_seconds( + logging.INFO, + "Epoch {}, Anytime results: {}".format(e, anytime_results), + n=5, + ) self._log_to_json() @@ -198,6 +221,28 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int if after_epoch is not None: after_epoch(e) + logger.info(f"Saving architectural weight tensors: {self.config.save_arch_weights_path}/epoch_{e}") + + # writing arch weights to file and plotting + if hasattr(self.config, "save_arch_weights"): + if self.config.save_arch_weights: + if not Path(f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_0.pt').exists(): + for idx in range(len(arch_weights_lst)): + if hasattr(self.config, "plot_arch_weights"): + if self.config.plot_arch_weights: + self._plot_architectural_weights(idx, alpha_i=arch_weights_lst[idx], epoch_num=e) + torch.save(arch_weights_lst[idx], + f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_{idx}.pt') + else: + for idx in range(len(self.optimizer.architectural_weights)): + old_arch_weights = torch.load(f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_{idx}.pt') + arch_weights_lst[idx] = torch.cat((old_arch_weights, arch_weights_lst[idx]), dim=0) + if hasattr(self.config, "plot_arch_weights"): + if self.config.plot_arch_weights: + self._plot_architectural_weights(idx, alpha_i=arch_weights_lst[idx], epoch_num=e) + torch.save(arch_weights_lst[idx], + f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_{idx}.pt') + self.optimizer.after_training() if summary_writer is not None: @@ -240,9 +285,9 @@ def evaluate_oneshot(self, resume_from="", dataloader=None): end_time = time.time() - self.search_trajectory.valid_acc.append(self.val_top1.avg) - self.search_trajectory.valid_loss.append(self.val_loss.avg) - self.search_trajectory.runtime.append(end_time - start_time) + 
self.errors_dict.valid_acc.append(self.val_top1.avg) + self.errors_dict.valid_loss.append(self.val_loss.avg) + self.errors_dict.runtime.append(end_time - start_time) self._log_to_json() @@ -250,13 +295,13 @@ def evaluate_oneshot(self, resume_from="", dataloader=None): return self.val_top1.avg def evaluate( - self, - retrain:bool=True, - search_model:str="", - resume_from:str="", - best_arch:Graph=None, - dataset_api:object=None, - metric:Metric=None, + self, + retrain: bool = True, + search_model: str = "", + resume_from: str = "", + best_arch: Graph = None, + dataset_api: object = None, + metric: Metric = None, ): """ Evaluate the final architecture as given from the optimizer. @@ -284,7 +329,7 @@ def evaluate( self._setup_checkpointers(search_model) # required to load the architecture best_arch = self.optimizer.get_final_architecture() - logger.info(f"Final architecture hash: {best_arch.get_hash()}") + logger.info("Final architecture:\n" + best_arch.modules_str()) if best_arch.QUERYABLE: if metric is None: @@ -293,7 +338,6 @@ def evaluate( metric=metric, dataset=self.config.dataset, dataset_api=dataset_api ) logger.info("Queried results ({}): {}".format(metric, result)) - return result else: best_arch.to(self.device) if retrain: @@ -366,14 +410,14 @@ def evaluate( logits_train = best_arch(input_train) train_loss = loss(logits_train, target_train) if hasattr( - best_arch, "auxilary_logits" + best_arch, "auxilary_logits" ): # darts specific stuff log_first_n(logging.INFO, "Auxiliary is used", n=10) auxiliary_loss = loss( best_arch.auxilary_logits(), target_train ) train_loss += ( - self.config.evaluation.auxiliary_weight * auxiliary_loss + self.config.evaluation.auxiliary_weight * auxiliary_loss ) train_loss.backward() if grad_clip: @@ -395,9 +439,8 @@ def evaluate( if self.valid_queue: best_arch.eval() for i, (input_valid, target_valid) in enumerate( - self.valid_queue + self.valid_queue ): - input_valid = input_valid.to(self.device).float() target_valid = 
def _plot_architectural_weights(self, idx, alpha_i, epoch_num):
    """Save a heatmap of one recorded architectural-weight tensor.

    The tensor is passed through a softmax over dimension 1, transposed and
    drawn with seaborn, so dimension 0 ends up on the x axis (labelled
    "steps") and dimension 1 on the y axis. The image is written to
    ``<config.save_arch_weights_path>/epoch_<epoch_num>/heatmap_<idx>.png``.

    Args:
        idx: Index of the tensor within the recorded list; used in the plot
            title and in the output file name.
        alpha_i: ``torch.Tensor`` with at least two dimensions (softmax is
            taken over ``dim=1``). Presumably (steps, num_choices) as stacked
            by ``search()`` -- confirm against the optimizer in use.
        epoch_num: Epoch number; selects the ``epoch_<n>`` output directory.
    """
    # TODO: check whether softmax is suitable here, e.g. in which range the
    # architectural weights of GDAS lie.
    weights = torch.softmax(alpha_i.detach(), dim=1).cpu().numpy()

    target = f"{self.config.save_arch_weights_path}/epoch_{epoch_num}/heatmap_{idx}.png"
    # Make the method usable on its own, not only after search() has already
    # created the per-epoch directory.
    os.makedirs(os.path.dirname(target), exist_ok=True)

    # Draw on a dedicated figure instead of pyplot's implicit "current"
    # figure, and always close it, so a failure while rendering/saving does
    # not leak the figure or bleed into later plots.
    fig, ax = plt.subplots()
    try:
        sns.heatmap(
            weights.T,
            cmap=sns.diverging_palette(230, 0, 90, 60, as_cmap=True),
            ax=ax,
        )
        ax.set_xticklabels(ax.get_xticklabels(), rotation=60)

        ax.set_title(f"arch weights for operation {idx}")
        ax.set_xlabel("steps")
        ax.set_ylabel("alpha values")
        fig.tight_layout()
        fig.savefig(target)
    finally:
        plt.close(fig)