Plotting convergence of alpha weights in One-Shot Optimizers #154

Open · wants to merge 7 commits into base: Develop
Changes from 3 commits
1 change: 1 addition & 0 deletions examples/example_runner.ipynb
@@ -79,6 +79,7 @@
}
],
"source": [
"from naslib.utils import utils\n",
"config = utils.get_config_from_args(config_type='nas')\n",
"\n",
"logger = setup_logger(config.save + \"/log.log\")\n",
26 changes: 26 additions & 0 deletions examples/plot_darts.py
@@ -0,0 +1,26 @@
import os
import logging
from naslib.defaults.trainer import Trainer
from naslib.optimizers import DARTSOptimizer, GDASOptimizer, RandomSearch
from naslib.search_spaces import DartsSearchSpace, SimpleCellSearchSpace

from naslib.utils import set_seed, setup_logger, get_config_from_args

config = get_config_from_args() # use --help to see the options
config.search.batch_size = 128
config.search.epochs = 1
config.save_arch_weights = True
config.plot_arch_weights = True
config.save_arch_weights_path = f"{config.save}/save_arch"
set_seed(config.seed)

logger = setup_logger(config.save + "/log.log")
logger.setLevel(logging.INFO) # default DEBUG is very verbose

search_space = SimpleCellSearchSpace() # DartsSearchSpace() # use SimpleCellSearchSpace() for a lighter search

optimizer = DARTSOptimizer(config)
optimizer.adapt_search_space(search_space)

trainer = Trainer(optimizer, config)
trainer.search()
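
A minimal sketch (not part of the PR) of how the tensors saved by this script can be loaded back for offline inspection. The epoch_<e>/tensor_<idx>.pt layout follows the trainer changes below; save_dir is a placeholder for whatever config.save_arch_weights_path resolves to.

from pathlib import Path
import torch

save_dir = Path("run/save_arch")  # placeholder for config.save_arch_weights_path

for epoch_dir in sorted(save_dir.glob("epoch_*")):
    for tensor_file in sorted(epoch_dir.glob("tensor_*.pt")):
        # For DARTS-style alphas each file typically holds a (steps_per_epoch, num_ops) tensor.
        alphas = torch.load(tensor_file)
        print(epoch_dir.name, tensor_file.name, tuple(alphas.shape))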
144 changes: 96 additions & 48 deletions naslib/defaults/trainer.py
@@ -8,6 +8,10 @@
import torch
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
Reviewer comment (Collaborator):
  1. Neither matplotlib.pyplot nor seaborn is used in this file. Remove.
  2. Seaborn is missing from requirements.txt.
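
If the imports are kept because _plot_architectural_weights (added at the bottom of this file) needs them, a common alternative is to import the plotting libraries lazily inside the helper, so seaborn is only required when arch-weight plotting is actually enabled. A rough sketch under that assumption, not the PR's actual code:

import numpy as np

def plot_arch_weights(alphas: np.ndarray, out_file: str) -> None:
    import matplotlib
    matplotlib.use("Agg")        # non-interactive backend, safe on headless machines
    import matplotlib.pyplot as plt
    import seaborn as sns

    ax = sns.heatmap(alphas.T)   # operations on the y-axis, optimization steps on the x-axis
    ax.set_xlabel("steps")
    ax.set_ylabel("alpha values")
    plt.tight_layout()
    plt.savefig(out_file)
    plt.close()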


from pathlib import Path
from fvcore.common.checkpoint import PeriodicCheckpointer

from naslib.search_spaces.core.query_metrics import Metric
@@ -57,8 +61,8 @@ def __init__(self, optimizer, config, lightweight_output=False):
self.val_loss = utils.AverageMeter()

n_parameters = optimizer.get_model_size()
# logger.info("param size = %fMB", n_parameters)
self.search_trajectory = utils.AttrDict(
logger.info("param size = %fMB", n_parameters)
self.errors_dict = utils.AttrDict(
Reviewer comment (Collaborator): I would keep the name as "search_trajectory", since it also contains non-error metrics such as accuracies, train_time and params.

{
"train_acc": [],
"train_loss": [],
@@ -73,7 +77,8 @@ def __init__(self, optimizer, config, lightweight_output=False):
}
)

def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int], None]=None, report_incumbent=True):
def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int], None] = None,
report_incumbent=True):
"""
Start the architecture search.

@@ -83,7 +88,7 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
resume_from (str): Checkpoint file to resume from. If not given then
train from scratch.
"""
logger.info("Beginning search")
logger.info("Start training")
Reviewer comment (Collaborator): Try and keep the changes in a PR relevant to the functionality it addresses. Reverting the commits of a PR should only affect the functionality it introduces.


np.random.seed(self.config.search.seed)
torch.manual_seed(self.config.search.seed)
@@ -108,11 +113,26 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int

for e in range(start_epoch, self.epochs):

# create the arch directory (without overwriting)
if self.config.save_arch_weights:
Reviewer comment (Collaborator): It is better for readability to write this as if self.config.save_arch_weights is True.

Path(f"{self.config.save_arch_weights_path}/epoch_{e}").mkdir(parents=True, exist_ok=False)

start_time = time.time()
self.optimizer.new_epoch(e)

arch_weights_lst = []
Reviewer comment (Collaborator): Generally, as a programming practice, don't mention the data structure in the name of the variable. arch_weights.append(...) makes it clear that it is a list, not a dict.
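
A minimal illustration of the suggested rename (a fragment of the loop below, everything else unchanged):

# Illustration only: the name says what is stored, not which container holds it.
arch_weights = []
for alpha_i in self.optimizer.architectural_weights:
    arch_weights.append(torch.unsqueeze(alpha_i.detach(), dim=0))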

if self.optimizer.using_step_function:
for step, data_train in enumerate(self.train_queue):

# save arch weights to array of tensors
if self.config.save_arch_weights:
if len(arch_weights_lst) == 0:
for alpha_i in self.optimizer.architectural_weights:
arch_weights_lst.append(torch.unsqueeze(alpha_i.detach(), dim=0))
else:
for idx, alpha_i in enumerate(self.optimizer.architectural_weights):
arch_weights_lst[idx] = torch.cat((arch_weights_lst[idx], torch.unsqueeze(alpha_i.detach(), dim=0)), dim=0)

data_train = (
data_train[0].to(self.device),
data_train[1].to(self.device, non_blocking=True),
@@ -151,11 +171,11 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int

end_time = time.time()

self.search_trajectory.train_acc.append(self.train_top1.avg)
self.search_trajectory.train_loss.append(self.train_loss.avg)
self.search_trajectory.valid_acc.append(self.val_top1.avg)
self.search_trajectory.valid_loss.append(self.val_loss.avg)
self.search_trajectory.runtime.append(end_time - start_time)
self.errors_dict.train_acc.append(self.train_top1.avg)
self.errors_dict.train_loss.append(self.train_loss.avg)
self.errors_dict.valid_acc.append(self.val_top1.avg)
self.errors_dict.valid_loss.append(self.val_loss.avg)
self.errors_dict.runtime.append(end_time - start_time)
else:
end_time = time.time()
# TODO: nasbench101 does not have train_loss, valid_loss, test_loss implemented, so this is a quick fix for now
@@ -168,28 +188,28 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
) = self.optimizer.train_statistics(report_incumbent)
train_loss, valid_loss, test_loss = -1, -1, -1

self.search_trajectory.train_acc.append(train_acc)
self.search_trajectory.train_loss.append(train_loss)
self.search_trajectory.valid_acc.append(valid_acc)
self.search_trajectory.valid_loss.append(valid_loss)
self.search_trajectory.test_acc.append(test_acc)
self.search_trajectory.test_loss.append(test_loss)
self.search_trajectory.runtime.append(end_time - start_time)
self.search_trajectory.train_time.append(train_time)
self.errors_dict.train_acc.append(train_acc)
self.errors_dict.train_loss.append(train_loss)
self.errors_dict.valid_acc.append(valid_acc)
self.errors_dict.valid_loss.append(valid_loss)
self.errors_dict.test_acc.append(test_acc)
self.errors_dict.test_loss.append(test_loss)
self.errors_dict.runtime.append(end_time - start_time)
self.errors_dict.train_time.append(train_time)
self.train_top1.avg = train_acc
self.val_top1.avg = valid_acc

self.periodic_checkpointer.step(e)

anytime_results = self.optimizer.test_statistics()
# if anytime_results:
if anytime_results:
# record anytime performance
# self.search_trajectory.arch_eval.append(anytime_results)
# log_every_n_seconds(
# logging.INFO,
# "Epoch {}, Anytime results: {}".format(e, anytime_results),
# n=5,
# )
self.errors_dict.arch_eval.append(anytime_results)
log_every_n_seconds(
logging.INFO,
"Epoch {}, Anytime results: {}".format(e, anytime_results),
n=5,
)

self._log_to_json()

@@ -198,6 +218,23 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
if after_epoch is not None:
after_epoch(e)

logger.info(f"Saving architectural weight tensors: {self.config.save_arch_weights_path}/epoch_{e}")

# writing arch weights to file and plotting
if self.config.save_arch_weights:
if not Path(f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_0.pt').exists():
for idx in range(len(arch_weights_lst)):
if self.config.plot_arch_weights:
self._plot_architectural_weights(idx, alpha_i=arch_weights_lst[idx], epoch_num=e)
torch.save(arch_weights_lst[idx], f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_{idx}.pt')
else:
for idx in range(len(self.optimizer.architectural_weights)):
old_arch_weights = torch.load(f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_{idx}.pt')
arch_weights_lst[idx] = torch.cat((old_arch_weights, arch_weights_lst[idx]), dim=0)
if self.config.plot_arch_weights:
self._plot_architectural_weights(idx, alpha_i=arch_weights_lst[idx], epoch_num=e)
torch.save(arch_weights_lst[idx], f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_{idx}.pt')

self.optimizer.after_training()

if summary_writer is not None:
@@ -240,23 +277,23 @@ def evaluate_oneshot(self, resume_from="", dataloader=None):

end_time = time.time()

self.search_trajectory.valid_acc.append(self.val_top1.avg)
self.search_trajectory.valid_loss.append(self.val_loss.avg)
self.search_trajectory.runtime.append(end_time - start_time)
self.errors_dict.valid_acc.append(self.val_top1.avg)
self.errors_dict.valid_loss.append(self.val_loss.avg)
self.errors_dict.runtime.append(end_time - start_time)

self._log_to_json()

logger.info("Evaluation finished")
return self.val_top1.avg

def evaluate(
self,
retrain:bool=True,
search_model:str="",
resume_from:str="",
best_arch:Graph=None,
dataset_api:object=None,
metric:Metric=None,
self,
retrain: bool = True,
search_model: str = "",
resume_from: str = "",
best_arch: Graph = None,
dataset_api: object = None,
metric: Metric = None,
):
"""
Evaluate the final architecture as given from the optimizer.
@@ -284,7 +321,7 @@ def evaluate(
self._setup_checkpointers(search_model) # required to load the architecture

best_arch = self.optimizer.get_final_architecture()
logger.info(f"Final architecture hash: {best_arch.get_hash()}")
logger.info("Final architecture:\n" + best_arch.modules_str())
Reviewer comment (Collaborator): Does not belong in this PR.


if best_arch.QUERYABLE:
if metric is None:
Expand All @@ -293,7 +330,6 @@ def evaluate(
metric=metric, dataset=self.config.dataset, dataset_api=dataset_api
)
logger.info("Queried results ({}): {}".format(metric, result))
return result
Reviewer comment (Collaborator): Why has this line been removed?

else:
best_arch.to(self.device)
if retrain:
@@ -366,14 +402,14 @@ def evaluate(
logits_train = best_arch(input_train)
train_loss = loss(logits_train, target_train)
if hasattr(
best_arch, "auxilary_logits"
best_arch, "auxilary_logits"
): # darts specific stuff
log_first_n(logging.INFO, "Auxiliary is used", n=10)
auxiliary_loss = loss(
best_arch.auxilary_logits(), target_train
)
train_loss += (
self.config.evaluation.auxiliary_weight * auxiliary_loss
self.config.evaluation.auxiliary_weight * auxiliary_loss
)
train_loss.backward()
if grad_clip:
@@ -395,9 +431,8 @@ def evaluate(
if self.valid_queue:
best_arch.eval()
for i, (input_valid, target_valid) in enumerate(
self.valid_queue
self.valid_queue
):

input_valid = input_valid.to(self.device).float()
target_valid = target_valid.to(self.device).float()

@@ -453,8 +488,6 @@ def evaluate(
)
)

return top1.avg
Reviewer comment (Collaborator): Why has this line been removed?


@staticmethod
def build_search_dataloaders(config):
train_queue, valid_queue, test_queue, _, _ = utils.get_train_val_loaders(
@@ -496,10 +529,12 @@ def build_eval_scheduler(optimizer, config):

def _log_and_reset_accuracies(self, epoch, writer=None):
logger.info(
"Epoch {} done. Train accuracy: {:.5f}, Validation accuracy: {:.5f}".format(
"Epoch {} done. Train accuracy (top1, top5): {:.5f}, {:.5f}, Validation accuracy: {:.5f}, {:.5f}".format(
epoch,
self.train_top1.avg,
self.train_top5.avg,
Reviewer comment (Collaborator): This is useful, but it doesn't belong with the viz PR.

self.val_top1.avg,
self.val_top5.avg,
)
)

@@ -550,7 +585,7 @@ def _prepare_dataloaders(self, config, mode="train"):
self.test_queue = test_queue

def _setup_checkpointers(
self, resume_from="", search=True, period=1, **add_checkpointables
self, resume_from="", search=True, period=1, **add_checkpointables
):
"""
Sets up a periodic checkpointer which can be used to save checkpoints
@@ -596,14 +631,27 @@ def _log_to_json(self):
os.makedirs(self.config.save)
if not self.lightweight_output:
with codecs.open(
os.path.join(self.config.save, "errors.json"), "w", encoding="utf-8"
os.path.join(self.config.save, "errors.json"), "w", encoding="utf-8"
) as file:
json.dump(self.search_trajectory, file, separators=(",", ":"))
json.dump(self.errors_dict, file, separators=(",", ":"))
else:
with codecs.open(
os.path.join(self.config.save, "errors.json"), "w", encoding="utf-8"
os.path.join(self.config.save, "errors.json"), "w", encoding="utf-8"
) as file:
lightweight_dict = copy.deepcopy(self.search_trajectory)
lightweight_dict = copy.deepcopy(self.errors_dict)
for key in ["arch_eval", "train_loss", "valid_loss", "test_loss"]:
lightweight_dict.pop(key)
json.dump([self.config, lightweight_dict], file, separators=(",", ":"))

def _plot_architectural_weights(self, idx, alpha_i, epoch_num):
# TODO: check whether softmax is suitable here. In which range are the weights for, e.g., GDAS?
alpha_i = torch.softmax(alpha_i.detach(), dim=1).cpu().numpy()
g = sns.heatmap(alpha_i.T, cmap=sns.diverging_palette(230, 0, 90, 60, as_cmap=True))
g.set_xticklabels(g.get_xticklabels(), rotation=60)

plt.title(f"arch weights for operation {idx}")
plt.xlabel("steps")
plt.ylabel("alpha values")
plt.tight_layout()
plt.savefig(f"{self.config.save_arch_weights_path}/epoch_{epoch_num}/heatmap_{idx}.png")
plt.close()
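
The TODO above asks whether softmax is the right normalization for every optimizer (e.g. GDAS). A small sketch, an assumption rather than part of the PR, for logging the raw value range of each saved tensor before committing to a normalization:

from pathlib import Path
import torch

def summarize_alpha_range(alpha: torch.Tensor) -> str:
    a = alpha.detach()
    return f"min={a.min().item():.4f} max={a.max().item():.4f} mean={a.mean().item():.4f}"

# Example usage on the tensors saved per epoch (path is a placeholder):
for f in sorted(Path("run/save_arch/epoch_0").glob("tensor_*.pt")):
    print(f.name, summarize_alpha_range(torch.load(f)))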