diff --git a/examples/example_runner.ipynb b/examples/example_runner.ipynb
index bdb2127bc..ea4581e22 100644
--- a/examples/example_runner.ipynb
+++ b/examples/example_runner.ipynb
@@ -79,6 +79,7 @@
    }
   ],
   "source": [
+    "from naslib.utils import utils\n",
     "config = utils.get_config_from_args(config_type='nas')\n",
     "\n",
     "logger = setup_logger(config.save + \"/log.log\")\n",
diff --git a/examples/plot_save_arch_weights.py b/examples/plot_save_arch_weights.py
new file mode 100644
index 000000000..15dd8dfac
--- /dev/null
+++ b/examples/plot_save_arch_weights.py
@@ -0,0 +1,48 @@
+import os
+import logging
+
+from naslib.defaults.trainer import Trainer
+from naslib.optimizers import DARTSOptimizer, GDASOptimizer, DrNASOptimizer
+from naslib.search_spaces import NasBench101SearchSpace, NasBench201SearchSpace, NasBench301SearchSpace
+
+from naslib.utils import set_seed, setup_logger, get_config_from_args, create_exp_dir
+from naslib.utils.vis import plot_architectural_weights
+
+config = get_config_from_args()  # use --help to see the options
+config.search.epochs = 50
+config.save_arch_weights = True
+config.plot_arch_weights = True
+config.optimizer = 'gdas'
+config.search_space = 'nasbench301'
+config.save = "{}/{}/{}/{}/{}".format(
+    config.out_dir, config.search_space, config.dataset, config.optimizer, config.seed
+)
+create_exp_dir(config.save)
+create_exp_dir(config.save + "/search")  # required for the checkpoints
+create_exp_dir(config.save + "/eval")
+
+optimizers = {
+    'gdas': GDASOptimizer(config),
+    'darts': DARTSOptimizer(config),
+    'drnas': DrNASOptimizer(config),
+}
+
+search_spaces = {
+    'nasbench101': NasBench101SearchSpace(),
+    'nasbench201': NasBench201SearchSpace(),
+    'nasbench301': NasBench301SearchSpace(),
+}
+
+set_seed(config.seed)
+
+logger = setup_logger(config.save + "/log.log")
+logger.setLevel(logging.INFO)  # default DEBUG is very verbose
+
+search_space = search_spaces[config.search_space]
+
+optimizer = optimizers[config.optimizer]
+optimizer.adapt_search_space(search_space)
+
+trainer = Trainer(optimizer, config)
+trainer.search()  # saves arch weights during search (see the re-plotting note below)
+
+plot_architectural_weights(config, optimizer)
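Note on the example above: a finished run already has arch_weights.pt on disk, so the plots can be regenerated without repeating the search. A minimal re-plotting sketch, not part of the patch; it assumes a completed GDAS/nasbench301 run under the same save-directory pattern, and that config.search.epochs matches the value used during search:

    from naslib.optimizers import GDASOptimizer
    from naslib.search_spaces import NasBench301SearchSpace
    from naslib.utils import get_config_from_args
    from naslib.utils.vis import plot_architectural_weights

    config = get_config_from_args()
    config.search.epochs = 50  # must match the finished run (alphas are averaged per epoch)
    config.optimizer = 'gdas'
    config.search_space = 'nasbench301'
    config.save = "{}/{}/{}/{}/{}".format(
        config.out_dir, config.search_space, config.dataset, config.optimizer, config.seed
    )

    optimizer = GDASOptimizer(config)
    optimizer.adapt_search_space(NasBench301SearchSpace())

    # reads {config.save}/arch_weights.pt and writes the heatmap PNGs next to it
    plot_architectural_weights(config, optimizer)

The optimizer only needs to be adapted to the search space so that plot_architectural_weights can walk the graph edges and read the op names; no weights are trained here.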
diff --git a/naslib/defaults/trainer.py b/naslib/defaults/trainer.py
index ccffda1c7..35e60fb63 100644
--- a/naslib/defaults/trainer.py
+++ b/naslib/defaults/trainer.py
@@ -8,12 +8,14 @@
 import torch
 import numpy as np
+from pathlib import Path
 
 from fvcore.common.checkpoint import PeriodicCheckpointer
 
 from naslib.search_spaces.core.query_metrics import Metric
 from naslib.utils import utils
 from naslib.utils.logging import log_every_n_seconds, log_first_n
+from naslib.utils.vis import plot_architectural_weights
 
 from typing import Callable
 
 from .additional_primitives import DropPathWrapper
@@ -57,7 +59,7 @@ def __init__(self, optimizer, config, lightweight_output=False):
         self.val_loss = utils.AverageMeter()
 
         n_parameters = optimizer.get_model_size()
-        # logger.info("param size = %fMB", n_parameters)
+        logger.info("param size = %fMB", n_parameters)
         self.search_trajectory = utils.AttrDict(
             {
                 "train_acc": [],
@@ -106,6 +108,7 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
             self.config
         )
 
+        arch_weights = []
        for e in range(start_epoch, self.epochs):
 
             start_time = time.time()
@@ -113,6 +116,16 @@
             if self.optimizer.using_step_function:
                 for step, data_train in enumerate(self.train_queue):
+
+                    # snapshot arch weights: one growing (steps, num_ops) tensor per edge
+                    if getattr(self.config, "save_arch_weights", False):
+                        if len(arch_weights) == 0:
+                            for edge_weights in self.optimizer.architectural_weights:
+                                arch_weights.append(torch.unsqueeze(edge_weights.detach(), dim=0))
+                        else:
+                            for i, edge_weights in enumerate(self.optimizer.architectural_weights):
+                                arch_weights[i] = torch.cat((arch_weights[i], torch.unsqueeze(edge_weights.detach(), dim=0)), dim=0)
+
                     data_train = (
                         data_train[0].to(self.device),
                         data_train[1].to(self.device, non_blocking=True),
@@ -183,13 +196,13 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
             anytime_results = self.optimizer.test_statistics()
 
             # if anytime_results:
-            # record anytime performance
-            # self.search_trajectory.arch_eval.append(anytime_results)
-            # log_every_n_seconds(
-            #     logging.INFO,
-            #     "Epoch {}, Anytime results: {}".format(e, anytime_results),
-            #     n=5,
-            # )
+            #     # record anytime performance
+            #     self.search_trajectory.arch_eval.append(anytime_results)
+            #     log_every_n_seconds(
+            #         logging.INFO,
+            #         "Epoch {}, Anytime results: {}".format(e, anytime_results),
+            #         n=5,
+            #     )
 
             self._log_to_json()
@@ -198,6 +211,13 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
             if after_epoch is not None:
                 after_epoch(e)
 
+        # save and possibly plot architectural weights
+        if getattr(self.config, "save_arch_weights", False):
+            logger.info(f"Saving architectural weight tensors: {self.config.save}/arch_weights.pt")
+            torch.save(arch_weights, f'{self.config.save}/arch_weights.pt')
+            if getattr(self.config, "plot_arch_weights", False):
+                plot_architectural_weights(self.config, self.optimizer)
+
         self.optimizer.after_training()
 
         if summary_writer is not None:
@@ -366,7 +386,7 @@ def evaluate(
                     logits_train = best_arch(input_train)
                     train_loss = loss(logits_train, target_train)
                     if hasattr(
-                        best_arch, "auxilary_logits"
+                            best_arch, "auxilary_logits"
                     ):  # darts specific stuff
                         log_first_n(logging.INFO, "Auxiliary is used", n=10)
                         auxiliary_loss = loss(
@@ -455,6 +475,7 @@ def evaluate(
 
         return top1.avg
 
+
     @staticmethod
     def build_search_dataloaders(config):
         train_queue, valid_queue, test_queue, _, _ = utils.get_train_val_loaders(
@@ -496,10 +517,10 @@ def build_eval_scheduler(optimizer, config):
 
     def _log_and_reset_accuracies(self, epoch, writer=None):
         logger.info(
-            "Epoch {} done. Train accuracy: {:.5f}, Validation accuracy: {:.5f}".format(
+            "Epoch {} done. Train accuracy: {:.5f}, Validation accuracy: {:.5f}".format(
                 epoch,
                 self.train_top1.avg,
-                self.val_top1.avg,
+                self.val_top1.avg
             )
         )
@@ -606,4 +627,4 @@ def _log_to_json(self):
         lightweight_dict = copy.deepcopy(self.search_trajectory)
         for key in ["arch_eval", "train_loss", "valid_loss", "test_loss"]:
             lightweight_dict.pop(key)
-        json.dump([self.config, lightweight_dict], file, separators=(",", ":"))
+        json.dump([self.config, lightweight_dict], file, separators=(",", ":"))
\ No newline at end of file
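The accumulation added to Trainer.search above snapshots optimizer.architectural_weights once per training step, so each alpha-carrying edge ends up with one tensor of shape (num_steps, num_ops). A self-contained toy sketch of that unsqueeze/cat pattern; the plain tensors stand in for NASLib's architectural parameters:

    import torch

    # stand-ins for optimizer.architectural_weights: two edges with 3 and 5 candidate ops
    alphas = [torch.zeros(3), torch.zeros(5)]

    arch_weights = []
    for step in range(4):  # pretend training steps
        if len(arch_weights) == 0:
            # first step: start one (1, num_ops) tensor per edge
            for edge_weights in alphas:
                arch_weights.append(torch.unsqueeze(edge_weights.detach(), dim=0))
        else:
            # later steps: append the current alphas along the step dimension
            for i, edge_weights in enumerate(alphas):
                arch_weights[i] = torch.cat(
                    (arch_weights[i], torch.unsqueeze(edge_weights.detach(), dim=0)), dim=0
                )

    print([tuple(w.shape) for w in arch_weights])  # [(4, 3), (4, 5)]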
diff --git a/naslib/utils/__init__.py b/naslib/utils/__init__.py
index db300252a..37f300f83 100644
--- a/naslib/utils/__init__.py
+++ b/naslib/utils/__init__.py
@@ -9,7 +9,8 @@
     parse_args,
     get_train_val_loaders,
     get_project_root,
-    compute_scores
+    compute_scores,
+    create_exp_dir
 )
 from .logging import setup_logger
 from .get_dataset_api import get_dataset_api, get_zc_benchmark_api, load_sampled_architectures
diff --git a/naslib/utils/utils.py b/naslib/utils/utils.py
index 999540c82..efb48ba48 100644
--- a/naslib/utils/utils.py
+++ b/naslib/utils/utils.py
@@ -322,8 +322,8 @@ def get_train_val_loaders(config, mode="train"):
     data = config.data
     dataset = config.dataset
     seed = config.search.seed
-    batch_size = config.batch_size
-    train_portion = config.train_portion
+    batch_size = config.batch_size if hasattr(config, "batch_size") else config.search.batch_size
+    train_portion = config.train_portion if hasattr(config, "train_portion") else config.search.train_portion
     config = config.search if mode == "train" else config.evaluation
     if dataset == "cifar10":
         train_transform, valid_transform = _data_transforms_cifar10(config)
diff --git a/naslib/utils/vis/__init__.py b/naslib/utils/vis/__init__.py
new file mode 100644
index 000000000..868c86a41
--- /dev/null
+++ b/naslib/utils/vis/__init__.py
@@ -0,0 +1,3 @@
+from .utils import (
+    plot_architectural_weights,
+)
\ No newline at end of file
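The plotting helper in the next file condenses the per-step alpha history to one row per epoch (mean over the steps of each epoch) before applying a softmax over the operations. A toy version of that reduction, with illustrative shapes:

    import torch

    epochs, steps_per_epoch, num_ops = 3, 5, 4
    history = torch.randn(epochs * steps_per_epoch, num_ops)  # (total_steps, num_ops)

    # group steps by epoch, then average within each epoch -> (epochs, num_ops)
    per_epoch = history.reshape(-1, steps_per_epoch, num_ops).mean(dim=1)
    dist = torch.softmax(per_epoch, dim=-1)  # each row is that epoch's op distribution
    print(per_epoch.shape, dist.sum(dim=-1))  # torch.Size([3, 4]), rows sum to 1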
diff --git a/naslib/utils/vis/utils.py b/naslib/utils/vis/utils.py
new file mode 100644
index 000000000..a9504ec10
--- /dev/null
+++ b/naslib/utils/vis/utils.py
@@ -0,0 +1,76 @@
+import logging
+import torch
+import numpy as np
+
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+logger = logging.getLogger(__name__)
+
+def plot_architectural_weights(config, optimizer):
+    all_weights = torch.load(f'{config.save}/arch_weights.pt')  # load alphas saved by the trainer
+
+    # unpack search space information
+    alpha_dict = {}
+    min_soft, max_soft = np.inf, -np.inf
+    for graph in optimizer.graph._get_child_graphs(single_instances=True):
+        for edge_weights, (u, v, edge_data) in zip(all_weights, graph.edges.data()):
+
+            if edge_data.has("alpha"):
+                total_steps, num_alphas = edge_weights.shape
+                steps_per_epoch = total_steps // config.search.epochs
+                disc_weights = torch.mean(edge_weights.detach().reshape(-1, steps_per_epoch, num_alphas), dim=1).cpu()
+                soft_weights = torch.softmax(disc_weights, dim=-1).numpy()
+
+                cell_name = edge_data['cell_name'] if hasattr(edge_data, 'cell_name') else ""
+                alpha_dict[(u, v, cell_name)] = {}
+                alpha_dict[(u, v, cell_name)]['op_names'] = [op.get_op_name for op in edge_data.op.get_embedded_ops()]
+                alpha_dict[(u, v, cell_name)]['alphas'] = soft_weights
+
+                min_soft = min(min_soft, np.min(soft_weights))
+                max_soft = max(max_soft, np.max(soft_weights))
+
+    max_rows = 4  # plot heatmaps in groups of at most max_rows edges
+    for start_id in range(0, len(alpha_dict), max_rows):
+
+        # calculate number of rows in this plot
+        n_rows = min(max_rows, len(alpha_dict) - start_id)
+        logger.info(f"Creating plot {config.save}/arch_weights_{start_id+1}to{start_id+n_rows}.png")
+
+        # define figure and axes and NASLib colormap
+        fig, axes = plt.subplots(nrows=n_rows, figsize=(10, max_rows))
+        axes = np.atleast_1d(axes)  # subplots() returns a bare Axes when n_rows == 1
+        cmap = sns.diverging_palette(230, 0, 90, 60, as_cmap=True)
+
+        # iterate over arch weights and create heatmaps
+        for ax_id, (u, v, cell_name) in enumerate(list(alpha_dict.keys())[start_id:start_id+n_rows]):
+            heatmap = sns.heatmap(
+                alpha_dict[u, v, cell_name]['alphas'].T,
+                cmap=cmap,
+                vmin=min_soft,
+                vmax=max_soft,
+                ax=axes[ax_id],
+                cbar=True
+            )
+
+            op_names = alpha_dict[(u, v, cell_name)]['op_names']
+
+            if ax_id < n_rows - 1:  # only the bottom subplot keeps its x-axis ticks
+                axes[ax_id].set_xticks([])
+            axes[ax_id].set_ylabel(f"{u, v}", fontdict=dict(fontsize=6))
+            axes[ax_id].set_yticks(np.arange(len(op_names)) + 0.5)
+            fontsize = max(6, 40 / len(op_names))
+            axes[ax_id].set_yticklabels(op_names, rotation=0, fontdict=dict(fontsize=fontsize))
+            if cell_name != "":
+                axes[ax_id].set_title(cell_name, fontdict=dict(fontsize=6))
+            cbar = heatmap.collections[0].colorbar
+            cbar.ax.tick_params(labelsize=6)
+            cbar.ax.set_title('softmax', fontdict=dict(fontsize=6))
+
+        axes[-1].xaxis.set_tick_params(labelsize=6)
+        axes[-1].set_xlabel("Epoch", fontdict=dict(fontsize=6))
+
+        fig.suptitle(f"optimizer: {config.optimizer}, search space: {config.search_space}, dataset: {config.dataset}, seed: {config.seed}")
+        fig.tight_layout()
+        fig.savefig(f"{config.save}/arch_weights_{start_id+1}to{start_id+n_rows}.png", dpi=300)
+        plt.close(fig)  # free the figure; several plots may be created in this loop
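Finally, the saved file can be inspected directly, independent of the plotting code. A small sketch, assuming a hypothetical run directory that follows the out_dir/search_space/dataset/optimizer/seed pattern from the example script:

    import torch

    # hypothetical run directory; adjust to an actual finished run
    run_dir = "run/nasbench301/cifar10/gdas/0"

    # list with one (num_steps, num_ops) tensor per alpha-carrying edge
    arch_weights = torch.load(f"{run_dir}/arch_weights.pt")

    for i, edge_weights in enumerate(arch_weights):
        final_dist = torch.softmax(edge_weights[-1], dim=-1)  # op distribution at the last step
        print(f"edge {i}: {tuple(edge_weights.shape)}, argmax op = {final_dist.argmax().item()}")

The list ordering follows optimizer.architectural_weights, which is why plot_architectural_weights zips it against graph.edges.data() in the same order.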