Plotting convergence of alpha weights in One-Shot Optimizers #154

Open
wants to merge 7 commits into Develop
1 change: 1 addition & 0 deletions examples/example_runner.ipynb
@@ -79,6 +79,7 @@
}
],
"source": [
"from naslib.utils import utils\n",
"config = utils.get_config_from_args(config_type='nas')\n",
"\n",
"logger = setup_logger(config.save + \"/log.log\")\n",
48 changes: 48 additions & 0 deletions examples/plot_save_arch_weights.py
@@ -0,0 +1,48 @@
import os
import logging
from naslib.defaults.trainer import Trainer
from naslib.optimizers import DARTSOptimizer, GDASOptimizer, DrNASOptimizer
from naslib.search_spaces import NasBench101SearchSpace, NasBench201SearchSpace, NasBench301SearchSpace

from naslib.utils import set_seed, setup_logger, get_config_from_args, create_exp_dir
Review comment (Collaborator): I tried running this file and it crashed because create_exp_dir is not imported in the __init__.py of utils.
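An interim workaround, sketched under the assumption that create_exp_dir is defined in naslib/utils/utils.py (which the __init__.py change further down in this PR suggests), would be to import it from the module directly until the package-level export lands:

from naslib.utils.utils import create_exp_dir  # assumed location; bypasses the missing package-level export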

from naslib.utils.vis import plot_architectural_weights

config = get_config_from_args() # use --help to see the options
config.search.epochs = 50
config.save_arch_weights = True
config.plot_arch_weights = True
config.optimizer = 'gdas'
config.search_space = 'nasbench301'
config.save = "{}/{}/{}/{}/{}".format(
config.out_dir, config.search_space, config.dataset, config.optimizer, config.seed
)
create_exp_dir(config.save)
create_exp_dir(config.save + "/search") # required for the checkpoints
create_exp_dir(config.save + "/eval")

optimizers = {
'gdas': GDASOptimizer(config),
'darts': DARTSOptimizer(config),
'drnas': DrNASOptimizer(config),
}

search_spaces = {
'nasbench101': NasBench101SearchSpace(),
'nasbench201': NasBench201SearchSpace(),
'nasbench301': NasBench301SearchSpace(),
}

set_seed(config.seed)

logger = setup_logger(config.save + "/log.log")
logger.setLevel(logging.INFO) # default DEBUG is very verbose

search_space = search_spaces[config.search_space]

optimizer = optimizers[config.optimizer]
optimizer.adapt_search_space(search_space)

trainer = Trainer(optimizer, config)
# trainer.search()

plot_architectural_weights(config, optimizer)
45 changes: 33 additions & 12 deletions naslib/defaults/trainer.py
@@ -8,12 +8,14 @@
import torch
import numpy as np

from pathlib import Path
from fvcore.common.checkpoint import PeriodicCheckpointer

from naslib.search_spaces.core.query_metrics import Metric

from naslib.utils import utils
from naslib.utils.logging import log_every_n_seconds, log_first_n
from naslib.utils.vis import plot_architectural_weights

from typing import Callable
from .additional_primitives import DropPathWrapper
@@ -57,7 +59,7 @@ def __init__(self, optimizer, config, lightweight_output=False):
self.val_loss = utils.AverageMeter()

n_parameters = optimizer.get_model_size()
# logger.info("param size = %fMB", n_parameters)
logger.info("param size = %fMB", n_parameters)
self.search_trajectory = utils.AttrDict(
{
"train_acc": [],
@@ -106,13 +108,24 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
self.config
)

arch_weights = []
for e in range(start_epoch, self.epochs):

start_time = time.time()
self.optimizer.new_epoch(e)

if self.optimizer.using_step_function:
for step, data_train in enumerate(self.train_queue):

# save arch weights to array of tensors
if self.config.save_arch_weights is True:
if len(arch_weights) == 0:
for edge_weights in self.optimizer.architectural_weights:
arch_weights.append(torch.unsqueeze(edge_weights.detach(), dim=0))
else:
for i, edge_weights in enumerate(self.optimizer.architectural_weights):
arch_weights[i] = torch.cat((arch_weights[i], torch.unsqueeze(edge_weights.detach(), dim=0)), dim=0)

data_train = (
data_train[0].to(self.device),
data_train[1].to(self.device, non_blocking=True),
@@ -183,13 +196,13 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int

anytime_results = self.optimizer.test_statistics()
# if anytime_results:
# record anytime performance
# self.search_trajectory.arch_eval.append(anytime_results)
# log_every_n_seconds(
# logging.INFO,
# "Epoch {}, Anytime results: {}".format(e, anytime_results),
# n=5,
# )
# # record anytime performance
# self.search_trajectory.arch_eval.append(anytime_results)
# log_every_n_seconds(
# logging.INFO,
# "Epoch {}, Anytime results: {}".format(e, anytime_results),
# n=5,
# )

self._log_to_json()

@@ -198,6 +211,13 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
if after_epoch is not None:
after_epoch(e)

# save and possibly plot architectural weights
logger.info(f"Saving architectural weight tensors: {self.config.save}/arch_weights.pt")
if hasattr(self.config, "save_arch_weights") and self.config.save_arch_weights:
Review comment (Collaborator): Use self.config.save_arch_weights is True for better readability.

torch.save(arch_weights, f'{self.config.save}/arch_weights.pt')
if hasattr(self.config, "plot_arch_weights") and self.config.plot_arch_weights:
Review comment (Collaborator): Likewise, self.config.plot_arch_weights is True.

plot_architectural_weights(self.config, self.optimizer)
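A minimal sketch of this save/plot block written in the explicit form the two comments above suggest (same behavior, assuming both config flags are plain booleans):

if hasattr(self.config, "save_arch_weights") and self.config.save_arch_weights is True:
    torch.save(arch_weights, f'{self.config.save}/arch_weights.pt')
if hasattr(self.config, "plot_arch_weights") and self.config.plot_arch_weights is True:
    plot_architectural_weights(self.config, self.optimizer)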

self.optimizer.after_training()

if summary_writer is not None:
@@ -366,7 +386,7 @@ def evaluate(
logits_train = best_arch(input_train)
train_loss = loss(logits_train, target_train)
if hasattr(
best_arch, "auxilary_logits"
best_arch, "auxilary_logits"
): # darts specific stuff
log_first_n(logging.INFO, "Auxiliary is used", n=10)
auxiliary_loss = loss(
@@ -455,6 +475,7 @@ def evaluate(

return top1.avg
Review comment (Collaborator): Why has this line been removed?



@staticmethod
def build_search_dataloaders(config):
train_queue, valid_queue, test_queue, _, _ = utils.get_train_val_loaders(
@@ -496,10 +517,10 @@ def build_eval_scheduler(optimizer, config):

def _log_and_reset_accuracies(self, epoch, writer=None):
logger.info(
"Epoch {} done. Train accuracy: {:.5f}, Validation accuracy: {:.5f}".format(
"Epoch {} done. Train accuracy: {:.5f}, Validation accuracy: {:.5f}".format(
epoch,
self.train_top1.avg,
self.val_top1.avg,
self.val_top1.avg
)
)

@@ -606,4 +627,4 @@ def _log_to_json(self):
lightweight_dict = copy.deepcopy(self.search_trajectory)
for key in ["arch_eval", "train_loss", "valid_loss", "test_loss"]:
lightweight_dict.pop(key)
json.dump([self.config, lightweight_dict], file, separators=(",", ":"))
json.dump([self.config, lightweight_dict], file, separators=(",", ":"))
3 changes: 2 additions & 1 deletion naslib/utils/__init__.py
@@ -9,7 +9,8 @@
parse_args,
get_train_val_loaders,
get_project_root,
compute_scores
compute_scores,
create_exp_dir
)
from .logging import setup_logger
from .get_dataset_api import get_dataset_api, get_zc_benchmark_api, load_sampled_architectures
4 changes: 2 additions & 2 deletions naslib/utils/utils.py
@@ -322,8 +322,8 @@ def get_train_val_loaders(config, mode="train"):
data = config.data
dataset = config.dataset
seed = config.search.seed
batch_size = config.batch_size
train_portion = config.train_portion
batch_size = config.batch_size if hasattr(config, "batch_size") else config.search.batch_size
Review comment (Collaborator): This looks like a bug fix independent of the visualization code. If so, create a new PR, or push the fix directly to Develop.

train_portion = config.train_portion if hasattr(config, "train_portion") else config.search.train_portion
config = config.search if mode == "train" else config.evaluation
if dataset == "cifar10":
train_transform, valid_transform = _data_transforms_cifar10(config)
3 changes: 3 additions & 0 deletions naslib/utils/vis/__init__.py
@@ -0,0 +1,3 @@
from .utils import (
plot_architectural_weights,
)
76 changes: 76 additions & 0 deletions naslib/utils/vis/utils.py
@@ -0,0 +1,76 @@
import logging
import torch
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
Review comment (Collaborator): Unused import.

import seaborn as sns

logger = logging.getLogger(__name__)

def plot_architectural_weights(config, optimizer):
all_weights = torch.load(f'{config.save}/arch_weights.pt') # load alphas

# unpack search space information
alpha_dict = {}
Review comment (Collaborator): Avoid data-structure name in var name.

min_soft, max_soft = np.inf, -np.inf
for graph in optimizer.graph._get_child_graphs(single_instances=True):
for edge_weights, (u, v, edge_data) in zip(all_weights, graph.edges.data()):
Review comment (Collaborator): There's a bug here. all_weights is a list of size 28 (in the case of nb301), each a tensor of size (n_steps, n_operations). The first 14 are from the normal cells, while the next 14 are from the reduction cells. The loop assigns the same alphas to both normal and reduction cells in line 29. (Screenshot attached.)
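One possible fix, sketched only, assuming arch_weights.pt stores one tensor per alpha edge in the order the child graphs are traversed (normal cell first, then reduction cell): advance a running index into all_weights instead of re-zipping from the start for every child graph.

weight_idx = 0
for graph in optimizer.graph._get_child_graphs(single_instances=True):
    for u, v, edge_data in graph.edges.data():
        if not edge_data.has("alpha"):
            continue
        edge_weights = all_weights[weight_idx]  # each child graph consumes its own slice of alphas
        weight_idx += 1
        # ... same per-edge processing as in the loop above ...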


if edge_data.has("alpha"):
total_steps, num_alphas = edge_weights.shape
steps_per_epoch = total_steps // config.search.epochs
disc_weights = torch.mean(edge_weights.detach().reshape(-1, steps_per_epoch, num_alphas), axis=1).cpu()
soft_weights = torch.softmax(disc_weights, dim=-1).numpy()

cell_name = edge_data['cell_name'] if hasattr(edge_data, 'cell_name') else ""
alpha_dict[(u, v, cell_name)] = {}
alpha_dict[(u, v, cell_name)]['op_names'] = [op.get_op_name for op in edge_data.op.get_embedded_ops()]
alpha_dict[(u, v, cell_name)]['alphas'] = soft_weights

min_soft = min(min_soft, np.min(soft_weights))
max_soft = max(max_soft, np.max(soft_weights))

max_rows = 4 # plot heatmaps in increments of n_rows edges
for start_id in range(0, len(alpha_dict.keys()), max_rows):

# calculate number of rows in plot
n_rows = min(max_rows, len(alpha_dict.keys())-start_id)
logger.info(f"Creating plot {config.save}/arch_weights_{start_id+1}to{start_id+n_rows}.png")

# define figure and axes and NASLib colormap
fig, axes = plt.subplots(nrows=n_rows, figsize=(10, max_rows))
cmap = sns.diverging_palette(230, 0, 90, 60, as_cmap=True)

# iterate over arch weights and create heatmaps
for ax_id, (u, v, cell_name) in enumerate(list(alpha_dict.keys())[start_id:start_id+n_rows]):
map = sns.heatmap(
alpha_dict[u, v, cell_name]['alphas'].T,
cmap=cmap,
vmin=min_soft,
vmax=max_soft,
ax=axes[ax_id],
cbar=True
)

op_names = alpha_dict[(u, v, cell_name)]['op_names']

if ax_id < n_rows-1:
axes[ax_id].set_xticks([])
axes[ax_id].set_ylabel(f"{u, v}", fontdict=dict(fontsize=6))
axes[ax_id].set_yticks(np.arange(len(op_names)) + 0.5)
fontsize = max(6, 40/len(op_names))
axes[ax_id].set_yticklabels(op_names, rotation=360, fontdict=dict(fontsize=fontsize))
if cell_name != "":
axes[ax_id].set_title(cell_name, fontdict=dict(fontsize=6))
cbar = map.collections[0].colorbar
cbar.ax.tick_params(labelsize=6)
cbar.ax.set_title('softmax', fontdict=dict(fontsize=6))

axes[ax_id].xaxis.set_tick_params(labelsize=6)
axes[ax_id].set_xlabel("Epoch", fontdict=dict(fontsize=6))

fig.suptitle(f"optimizer: {config.optimizer}, search space: {config.search_space}, dataset: {config.dataset}, seed: {config.seed}")
fig.tight_layout()
fig.savefig(f"{config.save}/arch_weights_{start_id+1}to{start_id+n_rows}.png", dpi=300)