Plotting convergence of alpha weights in One-Shot Optimizers #154

Open · wants to merge 7 commits into base: Develop

Conversation

@jr2021 (Collaborator) commented on Jan 30, 2023

Finalized Lukas' idea of a heat map that visualizes the convergence of the alpha weights in One-Shot optimizers, controlled via two new configuration parameters: config.save_arch_weights and config.plot_arch_weights.

The plotting was made extensible to larger search spaces by limiting the number of edges plotted to 4; this limit could also be exposed as a parameter if users want to control it themselves.
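
Roughly, the intended usage is to set the two flags on the config before running search (a sketch only; the exact config schema may differ, but the diff below reads them as plain boolean attributes on the config object):

from naslib.utils import get_config_from_args

config = get_config_from_args()     # load the usual NASLib search config
config.save_arch_weights = True     # save the alpha tensors to <config.save>/arch_weights.pt
config.plot_arch_weights = True     # additionally plot the convergence heat map from the saved tensors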

@jr2021 changed the title from "Dev alpha plotting" to "Plotting convergence of alpha weights in One-Shot Optimizers" on Jan 30, 2023
-# logger.info("param size = %fMB", n_parameters)
-self.search_trajectory = utils.AttrDict(
+logger.info("param size = %fMB", n_parameters)
+self.errors_dict = utils.AttrDict(
Collaborator:

I would keep the name as "search_trajectory", since it also contains non-error metrics such as accuracies, train_time and params

@@ -453,8 +488,6 @@ def evaluate(
)
)

-return top1.avg
Collaborator:

Why has this line been removed?

@@ -293,7 +330,6 @@ def evaluate(
metric=metric, dataset=self.config.dataset, dataset_api=dataset_api
)
logger.info("Queried results ({}): {}".format(metric, result))
-return result
Collaborator:

Why has this line been removed?

epoch,
self.train_top1.avg,
self.train_top5.avg,
Collaborator:

This is useful, but it doesn't belong with the viz PR

@@ -83,7 +88,7 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
resume_from (str): Checkpoint file to resume from. If not given then
train from scratch.
"""
logger.info("Beginning search")
logger.info("Start training")
Collaborator:

Try to keep the changes in a PR relevant to the functionality it addresses. Reverting the commits of a PR should only affect the functionality it introduces.

start_time = time.time()
self.optimizer.new_epoch(e)

arch_weights_lst = []
Collaborator:

Generally, as a programming practice, don't mention the data structure in the name of a variable. arch_weights.append(...) already makes it clear that it is a list, not a dict.

@@ -108,11 +113,26 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int

for e in range(start_epoch, self.epochs):

# create the arch directory (without overwriting)
if self.config.save_arch_weights:
Collaborator:

It is better for readability to write this as: if self.config.save_arch_weights is True

@@ -284,7 +321,7 @@ def evaluate(
self._setup_checkpointers(search_model) # required to load the architecture

best_arch = self.optimizer.get_final_architecture()
logger.info(f"Final architecture hash: {best_arch.get_hash()}")
logger.info("Final architecture:\n" + best_arch.modules_str())
Collaborator:

Does not belong in this PR

@@ -322,8 +322,8 @@ def get_train_val_loaders(config, mode="train"):
data = config.data
dataset = config.dataset
seed = config.search.seed
-batch_size = config.batch_size
train_portion = config.train_portion
+batch_size = config.batch_size if hasattr(config, "batch_size") else config.search.batch_size
Collaborator:

This looks like a bug fix independent of the visualization code. If so, create a new PR, or push the fix directly to Develop.

@@ -8,12 +8,17 @@
import torch
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
Collaborator:

  1. Neither matplotlib.pyplot nor seaborn is used in this file. Remove these imports.
  2. seaborn is missing from requirements.txt.

from naslib.optimizers import DARTSOptimizer, GDASOptimizer, DrNASOptimizer
from naslib.search_spaces import NasBench101SearchSpace, NasBench201SearchSpace, NasBench301SearchSpace

from naslib.utils import set_seed, setup_logger, get_config_from_args, create_exp_dir
Collaborator:

I tried running this file and it crashed because create_exp_dir is not imported in the __init__.py of utils
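
The fix is probably a one-line re-export (assuming create_exp_dir is defined in naslib/utils/utils.py; adjust the import path if it lives elsewhere):

# in naslib/utils/__init__.py
from .utils import create_exp_dir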

@@ -198,6 +215,13 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
if after_epoch is not None:
after_epoch(e)

# save and possibly plot architectural weights
logger.info(f"Saving architectural weight tensors: {self.config.save}/arch_weights.pt")
if hasattr(self.config, "save_arch_weights") and self.config.save_arch_weights:
Collaborator:

self.config.save_arch_weights is True for better readability

logger.info(f"Saving architectural weight tensors: {self.config.save}/arch_weights.pt")
if hasattr(self.config, "save_arch_weights") and self.config.save_arch_weights:
torch.save(arch_weights, f'{self.config.save}/arch_weights.pt')
if hasattr(self.config, "plot_arch_weights") and self.config.plot_arch_weights:
Collaborator:

self.config.plot_arch_weights is True

all_weights = torch.load(f'{config.save}/arch_weights.pt') # load alphas

# unpack search space information
alpha_dict = {}
Collaborator:

Avoid data-structure name in var name

import numpy as np

import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
Collaborator:

Unused import

alpha_dict = {}
min_soft, max_soft = np.inf, -np.inf
for graph in optimizer.graph._get_child_graphs(single_instances=True):
    for edge_weights, (u, v, edge_data) in zip(all_weights, graph.edges.data()):
Collaborator:

There's a bug here. all_weights is a list of size 28 (in the case of nb301), where each element is a tensor of shape (n_steps, n_operations). The first 14 come from the normal cell, while the next 14 come from the reduction cell. In line 29, the loop assigns the same alphas to both the normal and the reduction cells.
(Screenshot attached: 2023-02-01, 3:47 PM)
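
One way to fix it (a sketch, keeping the loop's original assumption that every edge of a child graph has one saved tensor in all_weights): advance an offset per child graph instead of always zipping from the start of the list, e.g.

offset = 0
for graph in optimizer.graph._get_child_graphs(single_instances=True):
    edges = list(graph.edges.data())
    # take this cell's slice of the saved alphas, then move the offset forward
    cell_weights = all_weights[offset:offset + len(edges)]
    offset += len(edges)
    for edge_weights, (u, v, edge_data) in zip(cell_weights, edges):
        ...  # build the per-edge alpha entries exactly as before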

@Neonkraft (Collaborator) left a comment:

Please address the comments :)

There seems to be a small bug in the plotting code.
