Plotting convergence of alpha weights in One-Shot Optimizers #154
base: Develop
Conversation
naslib/defaults/trainer.py (Outdated)

```diff
-        # logger.info("param size = %fMB", n_parameters)
-        self.search_trajectory = utils.AttrDict(
+        logger.info("param size = %fMB", n_parameters)
+        self.errors_dict = utils.AttrDict(
```
I would keep the name as "search_trajectory", since it also contains non-error metrics such as accuracies, train_time and params
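For reference, a minimal sketch of the kind of per-epoch metrics such a trajectory object accumulates (field names are illustrative, not taken from this diff; `SimpleNamespace` stands in for NASLib's `utils.AttrDict`):

```python
from types import SimpleNamespace  # stand-in for naslib.utils.AttrDict

# Illustrative only: the trainer's actual fields may differ.
search_trajectory = SimpleNamespace(
    train_acc=[],    # top-1 training accuracy per epoch
    train_loss=[],
    valid_acc=[],
    valid_loss=[],
    train_time=[],   # seconds spent training each epoch
    params=0,        # model size; a non-error metric, hence the name
)

search_trajectory.train_acc.append(92.4)
```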
```diff
@@ -453,8 +488,6 @@ def evaluate(
             )
         )

-        return top1.avg
```
Why has this line been removed?
naslib/defaults/trainer.py (Outdated)

```diff
@@ -293,7 +330,6 @@ def evaluate(
             metric=metric, dataset=self.config.dataset, dataset_api=dataset_api
         )
         logger.info("Queried results ({}): {}".format(metric, result))
-        return result
```
Why has this line been removed?
naslib/defaults/trainer.py (Outdated)

```diff
+            epoch,
+            self.train_top1.avg,
+            self.train_top5.avg,
```
This is useful, but it doesn't belong with the viz PR
naslib/defaults/trainer.py (Outdated)

```diff
@@ -83,7 +88,7 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
             resume_from (str): Checkpoint file to resume from. If not given then
             train from scratch.
         """
-        logger.info("Beginning search")
+        logger.info("Start training")
```
Try to keep the changes in a PR relevant to the functionality it addresses: reverting the commits of a PR should only affect the functionality it introduces.
naslib/defaults/trainer.py (Outdated)

```diff
             start_time = time.time()
             self.optimizer.new_epoch(e)

+            arch_weights_lst = []
```
Generally, as a programming practice, don't mention the data structure in the name of the variable. `arch_weights.append(...)` already makes it clear that it is a list, not a dict.
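A quick sketch of the suggested rename:

```python
import torch

weights = torch.rand(6, 5)  # dummy alpha tensor, for illustration only

# Before: the suffix redundantly encodes the container type.
arch_weights_lst = []
arch_weights_lst.append(weights)

# After: the append() call site already shows it is a list.
arch_weights = []
arch_weights.append(weights)
```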
naslib/defaults/trainer.py (Outdated)

```diff
@@ -108,11 +113,26 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int

         for e in range(start_epoch, self.epochs):

+            # create the arch directory (without overwriting)
+            if self.config.save_arch_weights:
```
It is better for readability to write this as `if self.config.save_arch_weights is True`.
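A small sketch of the difference (the config attribute here is a local stand-in):

```python
save_arch_weights = True  # stand-in for self.config.save_arch_weights

# Implicit truthiness: also passes for 1, non-empty strings, etc.
if save_arch_weights:
    print("saving")

# The suggested explicit form: passes only for the boolean True.
if save_arch_weights is True:
    print("saving")
```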
naslib/defaults/trainer.py (Outdated)

```diff
@@ -284,7 +321,7 @@ def evaluate(
         self._setup_checkpointers(search_model)  # required to load the architecture

         best_arch = self.optimizer.get_final_architecture()
-        logger.info(f"Final architecture hash: {best_arch.get_hash()}")
+        logger.info("Final architecture:\n" + best_arch.modules_str())
```
Does not belong in this PR
```diff
@@ -322,8 +322,8 @@ def get_train_val_loaders(config, mode="train"):
     data = config.data
     dataset = config.dataset
     seed = config.search.seed
-    batch_size = config.batch_size
-    train_portion = config.train_portion
+    batch_size = config.batch_size if hasattr(config, "batch_size") else config.search.batch_size
```
This looks like a bug fix independent of the visualization code. If so, create a new PR, or push the fix directly to Develop.
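If the fix is kept somewhere, a `getattr`-based helper (a sketch, assuming search-time settings nest under `config.search` as the diff suggests) would avoid repeating the `hasattr` check per field:

```python
from types import SimpleNamespace

def resolve(config, name):
    # Prefer a top-level attribute; otherwise fall back to config.search.
    return getattr(config, name, getattr(config.search, name))

# Minimal config mimicking the layout implied by the diff above.
config = SimpleNamespace(search=SimpleNamespace(batch_size=64, train_portion=0.7))
batch_size = resolve(config, "batch_size")        # -> 64, via the fallback
train_portion = resolve(config, "train_portion")  # -> 0.7
```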
naslib/defaults/trainer.py (Outdated)

```diff
@@ -8,12 +8,17 @@
 import torch
 import numpy as np

+import matplotlib.pyplot as plt
+import seaborn as sns
```
- Neither matplotlib.pyplot nor seaborn is used in this file. Remove them.
- seaborn is missing from requirements.txt.
```python
from naslib.optimizers import DARTSOptimizer, GDASOptimizer, DrNASOptimizer
from naslib.search_spaces import NasBench101SearchSpace, NasBench201SearchSpace, NasBench301SearchSpace

from naslib.utils import set_seed, setup_logger, get_config_from_args, create_exp_dir
```
I tried running this file and it crashed because `create_exp_dir` is not imported in the `__init__.py` of `utils`.
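A plausible fix, assuming `naslib/utils/__init__.py` re-exports these helpers from a submodule (the exact module path is a guess):

```python
# naslib/utils/__init__.py -- sketch; the actual source module may differ
from .utils import (
    set_seed,
    setup_logger,
    get_config_from_args,
    create_exp_dir,  # previously missing, causing the ImportError
)
```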
```diff
@@ -198,6 +215,13 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
             if after_epoch is not None:
                 after_epoch(e)

+            # save and possibly plot architectural weights
+            logger.info(f"Saving architectural weight tensors: {self.config.save}/arch_weights.pt")
+            if hasattr(self.config, "save_arch_weights") and self.config.save_arch_weights:
```
`self.config.save_arch_weights is True` for better readability.
```python
logger.info(f"Saving architectural weight tensors: {self.config.save}/arch_weights.pt")
if hasattr(self.config, "save_arch_weights") and self.config.save_arch_weights:
    torch.save(arch_weights, f'{self.config.save}/arch_weights.pt')
if hasattr(self.config, "plot_arch_weights") and self.config.plot_arch_weights:
```
`self.config.plot_arch_weights is True` for better readability, as above.
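Combining both readability comments, the two guards could be written with `getattr` defaults instead of `hasattr` (a sketch with a dummy config, not the PR's code):

```python
import torch
from types import SimpleNamespace

config = SimpleNamespace(save=".", save_arch_weights=True,
                         plot_arch_weights=False)  # illustrative config
arch_weights = [torch.rand(6, 5)]                  # dummy alphas

# getattr with a default replaces the hasattr checks from the diff.
if getattr(config, "save_arch_weights", False) is True:
    torch.save(arch_weights, f"{config.save}/arch_weights.pt")
if getattr(config, "plot_arch_weights", False) is True:
    pass  # plotting would be triggered here
```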
```python
all_weights = torch.load(f'{config.save}/arch_weights.pt')  # load alphas

# unpack search space information
alpha_dict = {}
```
Avoid data-structure name in var name
```python
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
```
Unused import
```python
alpha_dict = {}
min_soft, max_soft = np.inf, -np.inf
for graph in optimizer.graph._get_child_graphs(single_instances=True):
    for edge_weights, (u, v, edge_data) in zip(all_weights, graph.edges.data()):
```
Please address the comments :)
There seems to be a small bug in the plotting code.
Finalized Lukas' idea for a heat map to visualize the convergence of alpha weights in One-Shot optimizers, via two new configuration parameters: `config.save_arch_weights` and `config.plot_arch_weights`.

Plotting was made extensible to larger search spaces by limiting the number of plotted edges to 4; this limit could also be exposed as a parameter if users want to control it.
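For reviewers trying the feature out, here is a minimal sketch of the kind of heat map this produces; the file layout of `arch_weights.pt` is assumed (a list of per-edge lists of per-epoch alpha tensors), not taken from the PR:

```python
import torch
import matplotlib.pyplot as plt

# Assumed layout: one entry per edge, each a list of per-epoch alpha
# tensors of shape [num_ops]. The PR's actual format may differ.
all_weights = torch.load("arch_weights.pt")

max_edges = 4  # the PR caps the number of plotted edges at 4
fig, axes = plt.subplots(1, max_edges, figsize=(4 * max_edges, 3))

for ax, edge_weights in zip(axes, all_weights[:max_edges]):
    # rows = epochs, cols = operations; softmax maps alphas to weights
    mat = torch.stack([torch.softmax(w, dim=-1) for w in edge_weights])
    im = ax.imshow(mat.numpy(), aspect="auto", cmap="viridis", vmin=0, vmax=1)
    ax.set_xlabel("operation")
    ax.set_ylabel("epoch")

fig.colorbar(im, ax=axes.ravel().tolist(), label="softmax(alpha)")
plt.show()
```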