-
Notifications
You must be signed in to change notification settings - Fork 117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Plotting convergence of alpha weights in One-Shot Optimizers #154
base: Develop
Are you sure you want to change the base?
Changes from 3 commits
77e6191
d6fe57a
a7fba22
6ca93d9
d3d8300
0082349
3d689cf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import os | ||
import logging | ||
from naslib.defaults.trainer import Trainer | ||
from naslib.optimizers import DARTSOptimizer, GDASOptimizer, RandomSearch | ||
from naslib.search_spaces import DartsSearchSpace, SimpleCellSearchSpace | ||
|
||
from naslib.utils import set_seed, setup_logger, get_config_from_args | ||
|
||
config = get_config_from_args() # use --help so see the options | ||
config.search.batch_size = 128 | ||
config.search.epochs = 1 | ||
config.save_arch_weights = True | ||
config.plot_arch_weights = True | ||
config.save_arch_weights_path = f"{config.save}/save_arch" | ||
set_seed(config.seed) | ||
|
||
logger = setup_logger(config.save + "/log.log") | ||
logger.setLevel(logging.INFO) # default DEBUG is very verbose | ||
|
||
search_space = SimpleCellSearchSpace() # DartsSearchSpace() # use SimpleCellSearchSpace() for less heavy search | ||
|
||
optimizer = DARTSOptimizer(config) | ||
optimizer.adapt_search_space(search_space) | ||
|
||
trainer = Trainer(optimizer, config) | ||
trainer.search() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,10 @@ | |
import torch | ||
import numpy as np | ||
|
||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
|
||
from pathlib import Path | ||
from fvcore.common.checkpoint import PeriodicCheckpointer | ||
|
||
from naslib.search_spaces.core.query_metrics import Metric | ||
|
@@ -57,8 +61,8 @@ def __init__(self, optimizer, config, lightweight_output=False): | |
self.val_loss = utils.AverageMeter() | ||
|
||
n_parameters = optimizer.get_model_size() | ||
# logger.info("param size = %fMB", n_parameters) | ||
self.search_trajectory = utils.AttrDict( | ||
logger.info("param size = %fMB", n_parameters) | ||
self.errors_dict = utils.AttrDict( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would keep the name as "search_trajectory", since it also contains non-error metrics such as accuracies, train_time and params |
||
{ | ||
"train_acc": [], | ||
"train_loss": [], | ||
|
@@ -73,7 +77,8 @@ def __init__(self, optimizer, config, lightweight_output=False): | |
} | ||
) | ||
|
||
def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int], None]=None, report_incumbent=True): | ||
def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int], None] = None, | ||
report_incumbent=True): | ||
""" | ||
Start the architecture search. | ||
|
||
|
@@ -83,7 +88,7 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int | |
resume_from (str): Checkpoint file to resume from. If not given then | ||
train from scratch. | ||
""" | ||
logger.info("Beginning search") | ||
logger.info("Start training") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Try and keep the changes in a PR relevant to the functionality it addresses. Reversing the commits in a PR should only affect the functionality it introduces. |
||
|
||
np.random.seed(self.config.search.seed) | ||
torch.manual_seed(self.config.search.seed) | ||
|
@@ -108,11 +113,26 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int | |
|
||
for e in range(start_epoch, self.epochs): | ||
|
||
# create the arch directory (without overwriting) | ||
if self.config.save_arch_weights: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is better for readability to write this as |
||
Path(f"{self.config.save_arch_weights_path}/epoch_{e}").mkdir(parents=True, exist_ok=False) | ||
|
||
start_time = time.time() | ||
self.optimizer.new_epoch(e) | ||
|
||
arch_weights_lst = [] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Generally, as a programming practice, don't mention the data-structure in the name of the variable. |
||
if self.optimizer.using_step_function: | ||
for step, data_train in enumerate(self.train_queue): | ||
|
||
# save arch weights to array of tensors | ||
if self.config.save_arch_weights: | ||
if len(arch_weights_lst) == 0: | ||
for alpha_i in self.optimizer.architectural_weights: | ||
arch_weights_lst.append(torch.unsqueeze(alpha_i.detach(), dim=0)) | ||
else: | ||
for idx, alpha_i in enumerate(self.optimizer.architectural_weights): | ||
arch_weights_lst[idx] = torch.cat((arch_weights_lst[idx], torch.unsqueeze(alpha_i.detach(), dim=0)), dim=0) | ||
|
||
data_train = ( | ||
data_train[0].to(self.device), | ||
data_train[1].to(self.device, non_blocking=True), | ||
|
@@ -151,11 +171,11 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int | |
|
||
end_time = time.time() | ||
|
||
self.search_trajectory.train_acc.append(self.train_top1.avg) | ||
self.search_trajectory.train_loss.append(self.train_loss.avg) | ||
self.search_trajectory.valid_acc.append(self.val_top1.avg) | ||
self.search_trajectory.valid_loss.append(self.val_loss.avg) | ||
self.search_trajectory.runtime.append(end_time - start_time) | ||
self.errors_dict.train_acc.append(self.train_top1.avg) | ||
self.errors_dict.train_loss.append(self.train_loss.avg) | ||
self.errors_dict.valid_acc.append(self.val_top1.avg) | ||
self.errors_dict.valid_loss.append(self.val_loss.avg) | ||
self.errors_dict.runtime.append(end_time - start_time) | ||
else: | ||
end_time = time.time() | ||
# TODO: nasbench101 does not have train_loss, valid_loss, test_loss implemented, so this is a quick fix for now | ||
|
@@ -168,28 +188,28 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int | |
) = self.optimizer.train_statistics(report_incumbent) | ||
train_loss, valid_loss, test_loss = -1, -1, -1 | ||
|
||
self.search_trajectory.train_acc.append(train_acc) | ||
self.search_trajectory.train_loss.append(train_loss) | ||
self.search_trajectory.valid_acc.append(valid_acc) | ||
self.search_trajectory.valid_loss.append(valid_loss) | ||
self.search_trajectory.test_acc.append(test_acc) | ||
self.search_trajectory.test_loss.append(test_loss) | ||
self.search_trajectory.runtime.append(end_time - start_time) | ||
self.search_trajectory.train_time.append(train_time) | ||
self.errors_dict.train_acc.append(train_acc) | ||
self.errors_dict.train_loss.append(train_loss) | ||
self.errors_dict.valid_acc.append(valid_acc) | ||
self.errors_dict.valid_loss.append(valid_loss) | ||
self.errors_dict.test_acc.append(test_acc) | ||
self.errors_dict.test_loss.append(test_loss) | ||
self.errors_dict.runtime.append(end_time - start_time) | ||
self.errors_dict.train_time.append(train_time) | ||
self.train_top1.avg = train_acc | ||
self.val_top1.avg = valid_acc | ||
|
||
self.periodic_checkpointer.step(e) | ||
|
||
anytime_results = self.optimizer.test_statistics() | ||
# if anytime_results: | ||
if anytime_results: | ||
# record anytime performance | ||
# self.search_trajectory.arch_eval.append(anytime_results) | ||
# log_every_n_seconds( | ||
# logging.INFO, | ||
# "Epoch {}, Anytime results: {}".format(e, anytime_results), | ||
# n=5, | ||
# ) | ||
self.errors_dict.arch_eval.append(anytime_results) | ||
log_every_n_seconds( | ||
logging.INFO, | ||
"Epoch {}, Anytime results: {}".format(e, anytime_results), | ||
n=5, | ||
) | ||
|
||
self._log_to_json() | ||
|
||
|
@@ -198,6 +218,23 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int | |
if after_epoch is not None: | ||
after_epoch(e) | ||
|
||
logger.info(f"Saving architectural weight tensors: {self.config.save_arch_weights_path}/epoch_{e}") | ||
|
||
# writing arch weights to file and plotting | ||
if self.config.save_arch_weights: | ||
if not Path(f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_0.pt').exists(): | ||
for idx in range(len(arch_weights_lst)): | ||
if self.config.plot_arch_weights: | ||
self._plot_architectural_weights(idx, alpha_i=arch_weights_lst[idx], epoch_num=e) | ||
torch.save(arch_weights_lst[idx], f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_{idx}.pt') | ||
else: | ||
for idx in range(len(self.optimizer.architectural_weights)): | ||
old_arch_weights = torch.load(f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_{idx}.pt') | ||
arch_weights_lst[idx] = torch.cat((old_arch_weights, arch_weights_lst[idx]), dim=0) | ||
if self.config.plot_arch_weights: | ||
self._plot_architectural_weights(idx, alpha_i=arch_weights_lst[idx], epoch_num=e) | ||
torch.save(arch_weights_lst[idx], f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_{idx}.pt') | ||
|
||
self.optimizer.after_training() | ||
|
||
if summary_writer is not None: | ||
|
@@ -240,23 +277,23 @@ def evaluate_oneshot(self, resume_from="", dataloader=None): | |
|
||
end_time = time.time() | ||
|
||
self.search_trajectory.valid_acc.append(self.val_top1.avg) | ||
self.search_trajectory.valid_loss.append(self.val_loss.avg) | ||
self.search_trajectory.runtime.append(end_time - start_time) | ||
self.errors_dict.valid_acc.append(self.val_top1.avg) | ||
self.errors_dict.valid_loss.append(self.val_loss.avg) | ||
self.errors_dict.runtime.append(end_time - start_time) | ||
|
||
self._log_to_json() | ||
|
||
logger.info("Evaluation finished") | ||
return self.val_top1.avg | ||
|
||
def evaluate( | ||
self, | ||
retrain:bool=True, | ||
search_model:str="", | ||
resume_from:str="", | ||
best_arch:Graph=None, | ||
dataset_api:object=None, | ||
metric:Metric=None, | ||
self, | ||
retrain: bool = True, | ||
search_model: str = "", | ||
resume_from: str = "", | ||
best_arch: Graph = None, | ||
dataset_api: object = None, | ||
metric: Metric = None, | ||
): | ||
""" | ||
Evaluate the final architecture as given from the optimizer. | ||
|
@@ -284,7 +321,7 @@ def evaluate( | |
self._setup_checkpointers(search_model) # required to load the architecture | ||
|
||
best_arch = self.optimizer.get_final_architecture() | ||
logger.info(f"Final architecture hash: {best_arch.get_hash()}") | ||
logger.info("Final architecture:\n" + best_arch.modules_str()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does not belong in this PR |
||
|
||
if best_arch.QUERYABLE: | ||
if metric is None: | ||
|
@@ -293,7 +330,6 @@ def evaluate( | |
metric=metric, dataset=self.config.dataset, dataset_api=dataset_api | ||
) | ||
logger.info("Queried results ({}): {}".format(metric, result)) | ||
return result | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why has this line been removed? |
||
else: | ||
best_arch.to(self.device) | ||
if retrain: | ||
|
@@ -366,14 +402,14 @@ def evaluate( | |
logits_train = best_arch(input_train) | ||
train_loss = loss(logits_train, target_train) | ||
if hasattr( | ||
best_arch, "auxilary_logits" | ||
best_arch, "auxilary_logits" | ||
): # darts specific stuff | ||
log_first_n(logging.INFO, "Auxiliary is used", n=10) | ||
auxiliary_loss = loss( | ||
best_arch.auxilary_logits(), target_train | ||
) | ||
train_loss += ( | ||
self.config.evaluation.auxiliary_weight * auxiliary_loss | ||
self.config.evaluation.auxiliary_weight * auxiliary_loss | ||
) | ||
train_loss.backward() | ||
if grad_clip: | ||
|
@@ -395,9 +431,8 @@ def evaluate( | |
if self.valid_queue: | ||
best_arch.eval() | ||
for i, (input_valid, target_valid) in enumerate( | ||
self.valid_queue | ||
self.valid_queue | ||
): | ||
|
||
input_valid = input_valid.to(self.device).float() | ||
target_valid = target_valid.to(self.device).float() | ||
|
||
|
@@ -453,8 +488,6 @@ def evaluate( | |
) | ||
) | ||
|
||
return top1.avg | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why has this line been removed? |
||
|
||
@staticmethod | ||
def build_search_dataloaders(config): | ||
train_queue, valid_queue, test_queue, _, _ = utils.get_train_val_loaders( | ||
|
@@ -496,10 +529,12 @@ def build_eval_scheduler(optimizer, config): | |
|
||
def _log_and_reset_accuracies(self, epoch, writer=None): | ||
logger.info( | ||
"Epoch {} done. Train accuracy: {:.5f}, Validation accuracy: {:.5f}".format( | ||
"Epoch {} done. Train accuracy (top1, top5): {:.5f}, {:.5f}, Validation accuracy: {:.5f}, {:.5f}".format( | ||
epoch, | ||
self.train_top1.avg, | ||
self.train_top5.avg, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is useful, but it doesn't belong with the viz PR |
||
self.val_top1.avg, | ||
self.val_top5.avg, | ||
) | ||
) | ||
|
||
|
@@ -550,7 +585,7 @@ def _prepare_dataloaders(self, config, mode="train"): | |
self.test_queue = test_queue | ||
|
||
def _setup_checkpointers( | ||
self, resume_from="", search=True, period=1, **add_checkpointables | ||
self, resume_from="", search=True, period=1, **add_checkpointables | ||
): | ||
""" | ||
Sets up a periodic chechkpointer which can be used to save checkpoints | ||
|
@@ -596,14 +631,27 @@ def _log_to_json(self): | |
os.makedirs(self.config.save) | ||
if not self.lightweight_output: | ||
with codecs.open( | ||
os.path.join(self.config.save, "errors.json"), "w", encoding="utf-8" | ||
os.path.join(self.config.save, "errors.json"), "w", encoding="utf-8" | ||
) as file: | ||
json.dump(self.search_trajectory, file, separators=(",", ":")) | ||
json.dump(self.errors_dict, file, separators=(",", ":")) | ||
else: | ||
with codecs.open( | ||
os.path.join(self.config.save, "errors.json"), "w", encoding="utf-8" | ||
os.path.join(self.config.save, "errors.json"), "w", encoding="utf-8" | ||
) as file: | ||
lightweight_dict = copy.deepcopy(self.search_trajectory) | ||
lightweight_dict = copy.deepcopy(self.errors_dict) | ||
for key in ["arch_eval", "train_loss", "valid_loss", "test_loss"]: | ||
lightweight_dict.pop(key) | ||
json.dump([self.config, lightweight_dict], file, separators=(",", ":")) | ||
|
||
def _plot_architectural_weights(self, idx, alpha_i, epoch_num): | ||
# Todo check if softmax is suitable here. In which range are the weights for e.g. GDAS | ||
alpha_i = torch.softmax(alpha_i.detach(), dim=1).cpu().numpy() | ||
g = sns.heatmap(alpha_i.T, cmap=sns.diverging_palette(230, 0, 90, 60, as_cmap=True)) | ||
g.set_xticklabels(g.get_xticklabels(), rotation=60) | ||
|
||
plt.title(f"arch weights for operation {idx}") | ||
plt.xlabel("steps") | ||
plt.ylabel("alpha values") | ||
plt.tight_layout() | ||
plt.savefig(f"{self.config.save_arch_weights_path}/epoch_{epoch_num}/heatmap_{idx}.png") | ||
plt.close() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.