Skip to content

Commit

Permalink
Improve training logger behavior (rusty1s#138)
Browse files Browse the repository at this point in the history
* Specify task from config, refine logger behavior

* lint

* update

* explain logger behavior
  • Loading branch information
JiaxuanYou authored Dec 15, 2021
1 parent e87b187 commit 599087d
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 42 deletions.
35 changes: 19 additions & 16 deletions benchmark/train/main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import torch
import os.path as osp
import logging
import warnings
import os
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from torch_geometric import seed_everything
from torch_geometric.graphgym.cmd_args import parse_args
from torch_geometric.graphgym.config import (dump_cfg, set_run_dir, load_cfg,
get_fname)
from torch_geometric.graphgym.logger import set_printing, create_logger
from torch_geometric.graphgym.config import (load_cfg, get_fname)
from torch_geometric.graphgym.logger import set_printing
from torch_geometric.graphgym.utils.device import auto_select_device

from kumo.config import cfg
Expand All @@ -18,26 +18,35 @@
import kumo.custom # noqa
from kumo.utils.visualization import visualize_scalar_distribution

if __name__ == "__main__":
if __name__ == '__main__':
warnings.filterwarnings('ignore', '.*Lazy modules.*')
warnings.filterwarnings('ignore', '.*does not have many workers.*')
# Load cmd line args
args = parse_args()
# Load config file
load_cfg(cfg, args)
dump_cfg(cfg)
# Logger: create root directory for each random seed run
cfg_name = get_fname(args.cfg_file)
tb_logger = pl_loggers.TensorBoardLogger(cfg.out_dir, name=cfg_name,
sub_dir=str(cfg.seed + 1))
# Repeat for different random seeds
for i in range(args.repeat):
set_run_dir(cfg.out_dir, args.cfg_file)
set_printing()
# Set configurations for each run
cfg.seed = cfg.seed + 1
seed_everything(cfg.seed)
auto_select_device()
# Logger: create sub-directory and logger for each random seed run
tb_logger = pl_loggers.TensorBoardLogger(cfg.out_dir, name=cfg_name,
version=tb_logger.version,
sub_dir=str(cfg.seed))
cfg.run_dir = tb_logger.log_dir
set_printing()
with open(osp.join(cfg.run_dir, cfg.cfg_dest), 'w') as f:
cfg.dump(stream=f)

# Set machine learning pipeline
dataset, loaders = create_loader(cfg)
emb_size = get_emb_size(dataset)
loggers = create_logger()
warnings.filterwarnings("ignore", ".*Lazy modules.*")
model = create_model(cfg, dataset.metadata(), emb_size=emb_size)
# infer dim
model(dataset.feat_dict, dataset.edge_index_dict,
Expand All @@ -47,13 +56,7 @@
logging.info(cfg)

# Training
warnings.filterwarnings("ignore", ".*does not have many workers.*")

# Logger
tb_logger = pl_loggers.TensorBoardLogger("results/",
name=get_fname(args.cfg_file),
version=i)
gpus = 1 if cfg.device != "cpu" and torch.cuda.is_available() else 0
gpus = 1 if cfg.device != 'cpu' and torch.cuda.is_available() else 0
trainer = pl.Trainer(gpus=gpus, log_every_n_steps=1,
max_epochs=cfg.optim.max_epoch, logger=tb_logger)
# visualization
Expand Down
49 changes: 26 additions & 23 deletions kumo/custom/loader/imdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch_geometric.graphgym.register import register_loader

from kumo.store import Table, Store
from kumo.encoder import Index, OneHot
from kumo.encoder import Index
from kumo.connector import SnowflakeConnector


Expand All @@ -14,71 +14,73 @@ def load_imdb(format, name, dataset_dir, label_table_name, label_column_name):

cfg = {
'ACTORS': {
'primary_key':
'ACTORID',
'primary_key': 'ACTORID',
'encoders': [
('A_GENDER', OneHot(num_categories=2)),
('A_QUALITY', OneHot(num_categories=6)),
('A_GENDER', Index()),
('A_QUALITY', Index()),
],
'foreign_key_info': [],
},
'DIRECTORS': {
'primary_key':
'DIRECTORID',
'primary_key': 'DIRECTORID',
'encoders': [
('D_QUALITY', OneHot(num_categories=6)),
('AVG_REVENUE', OneHot(num_categories=5)),
('D_QUALITY', Index()),
('AVG_REVENUE', Index()),
],
'foreign_key_info': [],
},
'MOVIES': {
'primary_key':
'MOVIEID',
'encoders': [
('YEAR', OneHot(num_categories=4)),
('ISENGLISH', OneHot(num_categories=2)),
('COUNTRY', OneHot(num_categories=4)),
('RUNNINGTIME', OneHot(num_categories=4)),
('YEAR', Index()),
('ISENGLISH', Index()),
('COUNTRY', Index()),
('RUNNINGTIME', Index()),
],
'foreign_key_info': [],
},
'MOVIES2ACTORS': {
'primary_key': None,
'encoders': [
('CAST_NUM', OneHot(num_categories=4)),
('CAST_NUM', Index()),
],
'foreign_key_info': [('MOVIEID', 'MOVIES'), ('ACTORID', 'ACTORS')],
},
'MOVIES2DIRECTORS': {
'primary_key':
None,
'encoders': [
('GENRE', OneHot(num_categories=9)),
('GENRE', Index()),
],
'foreign_key_info': [('MOVIEID', 'MOVIES'),
('DIRECTORID', 'DIRECTORS')],
},
'U2BASE': {
'primary_key': None,
'encoders': [
('RATING', OneHot(num_categories=5)),
('RATING', Index()),
],
'foreign_key_info': [('USERID', 'USERS'), ('MOVIEID', 'MOVIES')],
'label_encoder': ('RATING', Index()),
},
'USERS': {
'primary_key':
'USERID',
'encoders': [
('AGE', OneHot(num_categories=7)),
('U_GENDER', OneHot(num_categories=2)),
('OCCUPATION', OneHot(num_categories=5)),
('AGE', Index()),
('U_GENDER', Index()),
('OCCUPATION', Index()),
],
'foreign_key_info': [],
# 'label_encoder': ('U_GENDER', Index()),
},
}

# set label column
# TODO: adapt IMDB with new loader format
encoders = cfg[label_table_name]['encoders']
cfg[label_table_name]['label_encoder'] = \
[encoder for encoder in encoders if label_column_name == encoder[0]][0]

if format == 'snowflake':

connector = SnowflakeConnector(
Expand All @@ -101,8 +103,9 @@ def load_imdb(format, name, dataset_dir, label_table_name, label_column_name):
else:
return None

# TODO: Temporary fix when a table has no node feature
data['U2BASE'].feat = torch.ones(data['U2BASE'].feat.shape[0], 1)
# Todo: temporary fix when a table has no node feature
if data['U2BASE'].feat.shape[1] == 0:
data['U2BASE'].feat = torch.ones(data['U2BASE'].feat.shape[0], 1)

return data

Expand Down
2 changes: 1 addition & 1 deletion kumo/model/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _get_pred_int(self, pred_score):
def _eval_auroc(self, label, pred):
# if pred_value is multi-class classification, compute softmax
if (self.cfg.dataset.task_type == "classification"
and (pred.dim() == 1 or pred.size(0) == pred.numel())):
and pred.dim() > 1 and pred.size(1) > 2):
# TODO: add pred = nn.Softmax(dim=1)(pred) and "ovo" for
# multi-class if imbalance eval is needed via auc
return 0
Expand Down
6 changes: 4 additions & 2 deletions kumo/train/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,16 @@ def load_dataset(cfg: CfgNode, **kwargs):
format = cfg.dataset.format
name = cfg.dataset.name
dataset_dir = osp.join(cfg.dataset.dir, name)
label_table = cfg.dataset.label_table
label_column = cfg.dataset.label_column
# Try to load customized data format
for func in register.loader_dict.values():
dataset = func(
format,
name,
dataset_dir,
cfg.dataset.label_table,
cfg.dataset.label_column,
label_table,
label_column,
)
if dataset is not None:
return dataset
Expand Down

0 comments on commit 599087d

Please sign in to comment.