Plotting convergence of alpha weights in One-Shot Optimizers #154
base: Develop
Changes from all commits: 77e6191, d6fe57a, a7fba22, 6ca93d9, d3d8300, 0082349, 3d689cf
@@ -0,0 +1,48 @@
import os
import logging
from naslib.defaults.trainer import Trainer
from naslib.optimizers import DARTSOptimizer, GDASOptimizer, DrNASOptimizer
from naslib.search_spaces import NasBench101SearchSpace, NasBench201SearchSpace, NasBench301SearchSpace

from naslib.utils import set_seed, setup_logger, get_config_from_args, create_exp_dir
from naslib.utils.vis import plot_architectural_weights

config = get_config_from_args()  # use --help to see the options
config.search.epochs = 50
config.save_arch_weights = True
config.plot_arch_weights = True
config.optimizer = 'gdas'
config.search_space = 'nasbench301'
config.save = "{}/{}/{}/{}/{}".format(
    config.out_dir, config.search_space, config.dataset, config.optimizer, config.seed
)
create_exp_dir(config.save)
create_exp_dir(config.save + "/search")  # required for the checkpoints
create_exp_dir(config.save + "/eval")

optimizers = {
    'gdas': GDASOptimizer(config),
    'darts': DARTSOptimizer(config),
    'drnas': DrNASOptimizer(config),
}

search_spaces = {
    'nasbench101': NasBench101SearchSpace(),
    'nasbench201': NasBench201SearchSpace(),
    'nasbench301': NasBench301SearchSpace(),
}

set_seed(config.seed)

logger = setup_logger(config.save + "/log.log")
logger.setLevel(logging.INFO)  # default DEBUG is very verbose

search_space = search_spaces[config.search_space]

optimizer = optimizers[config.optimizer]
optimizer.adapt_search_space(search_space)

trainer = Trainer(optimizer, config)
# trainer.search()

plot_architectural_weights(config, optimizer)
@@ -8,12 +8,14 @@
import torch
import numpy as np

from pathlib import Path
from fvcore.common.checkpoint import PeriodicCheckpointer

from naslib.search_spaces.core.query_metrics import Metric

from naslib.utils import utils
from naslib.utils.logging import log_every_n_seconds, log_first_n
from naslib.utils.vis import plot_architectural_weights

from typing import Callable
from .additional_primitives import DropPathWrapper
@@ -57,7 +59,7 @@ def __init__(self, optimizer, config, lightweight_output=False):
        self.val_loss = utils.AverageMeter()

        n_parameters = optimizer.get_model_size()
        # logger.info("param size = %fMB", n_parameters)
        logger.info("param size = %fMB", n_parameters)
        self.search_trajectory = utils.AttrDict(
            {
                "train_acc": [],
@@ -106,13 +108,24 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
            self.config
        )

        arch_weights = []
        for e in range(start_epoch, self.epochs):

            start_time = time.time()
            self.optimizer.new_epoch(e)

            if self.optimizer.using_step_function:
                for step, data_train in enumerate(self.train_queue):

                    # save arch weights to array of tensors
                    if self.config.save_arch_weights is True:
                        if len(arch_weights) == 0:
                            for edge_weights in self.optimizer.architectural_weights:
                                arch_weights.append(torch.unsqueeze(edge_weights.detach(), dim=0))
                        else:
                            for i, edge_weights in enumerate(self.optimizer.architectural_weights):
                                arch_weights[i] = torch.cat((arch_weights[i], torch.unsqueeze(edge_weights.detach(), dim=0)), dim=0)

                    data_train = (
                        data_train[0].to(self.device),
                        data_train[1].to(self.device, non_blocking=True),
@@ -183,13 +196,13 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int

            anytime_results = self.optimizer.test_statistics()
            # if anytime_results:
            # record anytime performance
            # self.search_trajectory.arch_eval.append(anytime_results)
            # log_every_n_seconds(
            # logging.INFO,
            # "Epoch {}, Anytime results: {}".format(e, anytime_results),
            # n=5,
            # )
            # # record anytime performance
            # self.search_trajectory.arch_eval.append(anytime_results)
            # log_every_n_seconds(
            # logging.INFO,
            # "Epoch {}, Anytime results: {}".format(e, anytime_results),
            # n=5,
            # )

            self._log_to_json()

@@ -198,6 +211,13 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
            if after_epoch is not None:
                after_epoch(e)

        # save and possibly plot architectural weights
        logger.info(f"Saving architectural weight tensors: {self.config.save}/arch_weights.pt")
        if hasattr(self.config, "save_arch_weights") and self.config.save_arch_weights:
            torch.save(arch_weights, f'{self.config.save}/arch_weights.pt')
        if hasattr(self.config, "plot_arch_weights") and self.config.plot_arch_weights:
            plot_architectural_weights(self.config, self.optimizer)

        self.optimizer.after_training()

        if summary_writer is not None:
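
For reference, a minimal sketch of how the saved tensors could be inspected after a run. This is not part of the PR; it only assumes the layout produced above, i.e. one tensor of shape (num_recorded_steps, num_operations) per alpha-carrying edge, written to {config.save}/arch_weights.pt.

import torch

# path follows the trainer code above: {config.save}/arch_weights.pt
arch_weights = torch.load("arch_weights.pt")

for i, edge_weights in enumerate(arch_weights):
    # softmax over the operation dimension gives per-step operation probabilities
    probs = torch.softmax(edge_weights, dim=-1)
    print(f"edge {i}: {edge_weights.shape[0]} recorded steps, "
          f"final preferred op index = {probs[-1].argmax().item()}")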
@@ -366,7 +386,7 @@ def evaluate(
                logits_train = best_arch(input_train)
                train_loss = loss(logits_train, target_train)
                if hasattr(
                    best_arch, "auxilary_logits"
                    best_arch, "auxilary_logits"
                ):  # darts specific stuff
                    log_first_n(logging.INFO, "Auxiliary is used", n=10)
                    auxiliary_loss = loss(
@@ -455,6 +475,7 @@ def evaluate(

            return top1.avg

Review comment: Why has this line been removed?

    @staticmethod
    def build_search_dataloaders(config):
        train_queue, valid_queue, test_queue, _, _ = utils.get_train_val_loaders(
@@ -496,10 +517,10 @@ def build_eval_scheduler(optimizer, config):

    def _log_and_reset_accuracies(self, epoch, writer=None):
        logger.info(
            "Epoch {} done. Train accuracy: {:.5f}, Validation accuracy: {:.5f}".format(
            "Epoch {} done. Train accuracy: {:.5f}, Validation accuracy: {:.5f}".format(
                epoch,
                self.train_top1.avg,
                self.val_top1.avg,
                self.val_top1.avg
            )
        )

@@ -606,4 +627,4 @@ def _log_to_json(self):
            lightweight_dict = copy.deepcopy(self.search_trajectory)
            for key in ["arch_eval", "train_loss", "valid_loss", "test_loss"]:
                lightweight_dict.pop(key)
            json.dump([self.config, lightweight_dict], file, separators=(",", ":"))
            json.dump([self.config, lightweight_dict], file, separators=(",", ":"))
@@ -322,8 +322,8 @@ def get_train_val_loaders(config, mode="train"):
    data = config.data
    dataset = config.dataset
    seed = config.search.seed
    batch_size = config.batch_size
    train_portion = config.train_portion
    batch_size = config.batch_size if hasattr(config, "batch_size") else config.search.batch_size

Review comment: This looks like a bug fix independent of the visualization code. If so, create a new PR, or push the fix directly to Develop.

    train_portion = config.train_portion if hasattr(config, "train_portion") else config.search.train_portion
    config = config.search if mode == "train" else config.evaluation
    if dataset == "cifar10":
        train_transform, valid_transform = _data_transforms_cifar10(config)
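
To make the fallback in the two changed lines explicit, here is a standalone sketch of the pattern (illustrative names only, not NASLib code): prefer a top-level config attribute and fall back to the nested search config when it is absent.

from types import SimpleNamespace

# a config that only defines batch_size/train_portion under config.search
config = SimpleNamespace(search=SimpleNamespace(batch_size=64, train_portion=0.5))

# prefer the top-level attribute, otherwise fall back to the nested search config
batch_size = config.batch_size if hasattr(config, "batch_size") else config.search.batch_size
train_portion = config.train_portion if hasattr(config, "train_portion") else config.search.train_portion
print(batch_size, train_portion)  # -> 64 0.5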
@@ -0,0 +1,3 @@
from .utils import (
    plot_architectural_weights,
)
@@ -0,0 +1,76 @@
import logging
import torch
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable

Review comment: Unused import

import seaborn as sns

logger = logging.getLogger(__name__)

def plot_architectural_weights(config, optimizer):
    all_weights = torch.load(f'{config.save}/arch_weights.pt')  # load alphas

    # unpack search space information
    alpha_dict = {}

Review comment: Avoid data-structure name in var name

    min_soft, max_soft = np.inf, -np.inf
    for graph in optimizer.graph._get_child_graphs(single_instances=True):
        for edge_weights, (u, v, edge_data) in zip(all_weights, graph.edges.data()):

            if edge_data.has("alpha"):
                total_steps, num_alphas = edge_weights.shape
                steps_per_epoch = total_steps // config.search.epochs
                disc_weights = torch.mean(edge_weights.detach().reshape(-1, steps_per_epoch, num_alphas), axis=1).cpu()
                soft_weights = torch.softmax(disc_weights, dim=-1).numpy()

                cell_name = edge_data['cell_name'] if hasattr(edge_data, 'cell_name') else ""
                alpha_dict[(u, v, cell_name)] = {}
                alpha_dict[(u, v, cell_name)]['op_names'] = [op.get_op_name for op in edge_data.op.get_embedded_ops()]
                alpha_dict[(u, v, cell_name)]['alphas'] = soft_weights

                min_soft = min(min_soft, np.min(soft_weights))
                max_soft = max(max_soft, np.max(soft_weights))

    max_rows = 4  # plot heatmaps in increments of n_rows edges
    for start_id in range(0, len(alpha_dict.keys()), max_rows):

        # calculate number of rows in plot
        n_rows = min(max_rows, len(alpha_dict.keys()) - start_id)
        logger.info(f"Creating plot {config.save}/arch_weights_{start_id+1}to{start_id+n_rows}.png")

        # define figure and axes and NASLib colormap
        fig, axes = plt.subplots(nrows=n_rows, figsize=(10, max_rows))
        cmap = sns.diverging_palette(230, 0, 90, 60, as_cmap=True)

        # iterate over arch weights and create heatmaps
        for ax_id, (u, v, cell_name) in enumerate(list(alpha_dict.keys())[start_id:start_id+n_rows]):
            map = sns.heatmap(
                alpha_dict[u, v, cell_name]['alphas'].T,
                cmap=cmap,
                vmin=min_soft,
                vmax=max_soft,
                ax=axes[ax_id],
                cbar=True
            )

            op_names = alpha_dict[(u, v, cell_name)]['op_names']

            if ax_id < n_rows-1:
                axes[ax_id].set_xticks([])
            axes[ax_id].set_ylabel(f"{u, v}", fontdict=dict(fontsize=6))
            axes[ax_id].set_yticks(np.arange(len(op_names)) + 0.5)
            fontsize = max(6, 40/len(op_names))
            axes[ax_id].set_yticklabels(op_names, rotation=360, fontdict=dict(fontsize=fontsize))
            if cell_name != "":
                axes[ax_id].set_title(cell_name, fontdict=dict(fontsize=6))
            cbar = map.collections[0].colorbar
            cbar.ax.tick_params(labelsize=6)
            cbar.ax.set_title('softmax', fontdict=dict(fontsize=6))

            axes[ax_id].xaxis.set_tick_params(labelsize=6)
            axes[ax_id].set_xlabel("Epoch", fontdict=dict(fontsize=6))

        fig.suptitle(f"optimizer: {config.optimizer}, search space: {config.search_space}, dataset: {config.dataset}, seed: {config.seed}")
        fig.tight_layout()
        fig.savefig(f"{config.save}/arch_weights_{start_id+1}to{start_id+n_rows}.png", dpi=300)

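To make the reshape-and-average step in the function above concrete, here is a small self-contained sketch with made-up shapes (100 recorded steps, 50 epochs, 8 candidate operations); it only illustrates the computation, it is not NASLib code.

import torch

total_steps, num_alphas, epochs = 100, 8, 50          # hypothetical run
edge_weights = torch.randn(total_steps, num_alphas)   # raw alphas recorded at every step

steps_per_epoch = total_steps // epochs               # 2 steps per epoch in this example
# average the raw alphas within each epoch, then softmax over operations
per_epoch = edge_weights.reshape(-1, steps_per_epoch, num_alphas).mean(dim=1)
soft = torch.softmax(per_epoch, dim=-1)               # shape (epochs, num_alphas), rows sum to 1
assert soft.shape == (epochs, num_alphas)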
Review comment: I tried running this file and it crashed because create_exp_dir is not imported in the __init__.py of utils.
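
A minimal sketch of the kind of fix this comment points at, assuming create_exp_dir is defined in naslib/utils/utils.py (the module layout is an assumption, it is not shown in this diff):

# naslib/utils/__init__.py (sketch)
from .utils import create_exp_dir  # re-export so `from naslib.utils import create_exp_dir` works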