Commit

upload

AllenWrong committed Jan 16, 2024
1 parent a977dda commit efcd78a

Showing 5 changed files with 347 additions and 0 deletions.
24 changes: 24 additions & 0 deletions logger.py
@@ -0,0 +1,24 @@
import logging


def setup_logger(name, log_file, level=logging.INFO):
    """Set up a named logger that writes to both a file and the console."""

    # Create a logger
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Create a handler that writes to the log file
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s | %(message)s'))

    # Create a handler that prints to the console
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s | %(message)s'))

    # Attach both handlers to the logger. Note: calling this function twice
    # with the same name attaches duplicate handlers.
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    return logger
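
A minimal usage sketch (not part of the commit; the logger name and file path are placeholders):

logger = setup_logger('demo', 'train.log')
logger.info('training started')             # appears in train.log and on the console
logger.warning('lr looks high: %.3f', 0.5)  # printf-style args are supported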

41 changes: 41 additions & 0 deletions metrics.py
@@ -0,0 +1,41 @@
from torchmetrics import AUROC, Accuracy, Precision, Recall, F1Score
from enum import Enum


class Metrics(Enum):
    acc = 0
    auc = 1
    precision = 2
    recall = 3
    f1 = 4


def get_metric_dict(num_classes, metrics):
    """
    Args:
        num_classes: number of classes
        metrics: an iterable of `Metrics` members, e.g. `[Metrics.acc, Metrics.auc, ...]`
    Return:
        `{Metrics.acc: Accuracy(num_classes=num_classes, average='macro'), ...}`
    """
    task_type = 'binary' if num_classes == 2 else 'multiclass'
    metrics_dict = {}

    for key in metrics:
        if key == Metrics.acc:
            metrics_dict[key] = Accuracy(task=task_type, average='macro', num_classes=num_classes)
        elif key == Metrics.auc:
            metrics_dict[key] = AUROC(num_classes=num_classes, task=task_type)
        elif key == Metrics.precision:
            metrics_dict[key] = Precision(num_classes=num_classes, average='macro', task=task_type)
        elif key == Metrics.recall:
            metrics_dict[key] = Recall(num_classes=num_classes, average='macro', task=task_type)
        elif key == Metrics.f1:
            metrics_dict[key] = F1Score(num_classes=num_classes, average='macro', task=task_type)
        else:
            raise ValueError(f"unsupported metric name {key.name}")

    return metrics_dict


# Module-level example: all five metrics for a 3-class task (runs at import time).
metrics = get_metric_dict(3, [Metrics.acc, Metrics.auc, Metrics.precision, Metrics.recall, Metrics.f1])
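
A quick sanity check of the returned dict (a sketch, not in the commit); torchmetrics' multiclass convention is (N, C) probabilities or logits plus (N,) integer targets:

import torch

preds = torch.randn(8, 3).softmax(dim=-1)   # 8 samples, 3 classes
targets = torch.randint(0, 3, (8,))
for key, metric in metrics.items():
    print(key.name, metric(preds, targets).item())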
14 changes: 14 additions & 0 deletions run.sh
@@ -0,0 +1,14 @@
python demo.py \
    --seed 42 \
    --epochs 5 \
    --batch_size 128 \
    --device 'cuda:2' \
    --lr 1e-3 \
    --weight_decay 0.0 \
    --data_path './data/feature/' \
    --ckp_path './ckps/resume1' \
    --use_sched_ckp false \
    --out_ckp_path './ckps/resume2' \
    --log_dir './log/resume1'

# --ckp_path './ckps/resume' \
185 changes: 185 additions & 0 deletions trainer.py
@@ -0,0 +1,185 @@
import os
import json
from torch.utils.tensorboard.writer import SummaryWriter
import torch
from logger import setup_logger
from tqdm import tqdm
import numpy as np


class Trainer:
    def __init__(
        self,
        device,
        model,
        criterion,
        optimizer,
        args,
        metrics_dict,
        sched=None,
        use_writer=True
    ):
        self.device = device
        self.model = model.to(device)
        self.optimizer = optimizer
        self.sched = sched
        self.criterion = criterion
        self.args = args
        self.base_epoch = args.base_epoch
        self._pre_config()

        if args.ckp_path is not None:
            json_info = self.load_torch_model()
            self.base_epoch = json_info['epoch'] + 1

        if use_writer:
            self.writer = SummaryWriter(log_dir=args.log_dir)
        else:
            self.writer = None
        # Move metric states onto the same device as the model outputs;
        # torchmetrics raises a device-mismatch error otherwise.
        self.metrics_dict = {k: m.to(device) for k, m in metrics_dict.items()}

    def _pre_config(self):
        # Create the output checkpoint directory, persist the run args, set up logging.
        os.makedirs(self.args.out_ckp_path, exist_ok=True)
        with open(os.path.join(self.args.out_ckp_path, 'args.json'), 'w') as f:
            json.dump(self.args.__dict__, f)
        self.logger = setup_logger('trainer', os.path.join(self.args.out_ckp_path, 'trainer.log'))

    def fit(self, train_loader, valid_loader, test_loader=None):
        best_auc = 0.0
        for epoch in range(self.base_epoch, self.base_epoch + self.args.epochs):
            losses = []
            preds = []
            targets = []

            progress_bar = tqdm(train_loader, leave=False)
            for x, y in progress_bar:
                x = x.to(self.device)
                y = y.to(self.device)

                sim_logits = self.model(x)
                loss = self.criterion(sim_logits, y)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                losses.append(loss.item())

                progress_bar.set_postfix_str(f'loss={loss.item()}')

                preds.append(sim_logits.detach())
                targets.append(y.detach())

            preds = torch.concat(preds, dim=0)
            targets = torch.concat(targets, dim=0)
            loss_val = np.mean(losses)
            train_metrics_val = self._compute_metrics(preds, targets)
            train_metrics_val.update({'loss': loss_val})

            if self.sched is not None:
                self.sched.step()

            valid_metrics_val = self.valid(valid_loader)
            if valid_metrics_val['auc'] > best_auc:
                best_auc = valid_metrics_val['auc']
                self.save_torch_model(
                    {'epoch': epoch, 'train_auc': train_metrics_val['auc'], 'valid_auc': valid_metrics_val['auc']}
                )

            self._write_tensorboard(train_metrics_val, valid_metrics_val, self.optimizer.param_groups[0]['lr'], epoch)
            self._print_metrics_val(epoch, train_metrics_val, valid_metrics_val)

    def _print_metrics_val(self, epoch, train_met_val, valid_met_val, test_met_val=None):
        info = f"[epoch: {epoch}] "
        for k, v in train_met_val.items():
            info += f"{k}: {v:.6f}, "
        for k, v in valid_met_val.items():
            info += f"{k}: {v:.6f}, "
        if test_met_val is not None:
            for k, v in test_met_val.items():
                info += f"{k}: {v:.6f}, "
        self.logger.info(info)

    def _write_tensorboard(self, train_met_val, valid_met_val, lr, step, test_met_val=None):
        """
        Args:
            train_met_val/valid_met_val/test_met_val: dicts of `{metric_name: value}`
            lr: current learning rate
            step: global step (here, the epoch index)
        """
        def _helper(cate, met_val):
            for k, v in met_val.items():
                self.writer.add_scalar(f'{k}/{cate}', v, step)
        _helper('train', train_met_val)
        _helper('valid', valid_met_val)
        if test_met_val is not None:
            _helper('test', test_met_val)
        self.writer.add_scalar('lr', lr, step)

    @torch.no_grad()
    def _compute_metrics(self, preds: torch.Tensor, targets: torch.Tensor) -> dict:
        """
        Return:
            `{acc: 0.99999, auc: 0.9999, ...}`
        """
        metrics_res = {}
        for k, fn in self.metrics_dict.items():
            metrics_res[k.name] = fn(preds, targets).item()
        return metrics_res

    @torch.no_grad()
    def valid(self, valid_loader):
        self.model.eval()

        losses = []
        targets = []
        preds = []
        progress_bar = tqdm(valid_loader, leave=False)
        for x, y in progress_bar:
            x = x.to(self.device)
            y = y.to(self.device)
            sim_logits = self.model(x)
            loss = self.criterion(sim_logits, y).item()
            losses.append(loss)
            targets.append(y.detach())
            preds.append(sim_logits.detach())
            progress_bar.set_postfix_str(f'loss={loss}')
        self.model.train()

        preds = torch.concat(preds, dim=0)
        targets = torch.concat(targets, dim=0)
        eval_metrics_val = self._compute_metrics(preds, targets)
        eval_metrics_val.update({'loss': np.mean(losses)})
        return eval_metrics_val

    def load_torch_model(self) -> dict:
        """Load state dicts and return the checkpoint info."""
        path_dir = self.args.ckp_path
        info = f'- loaded from {path_dir}, for model'
        self.model.load_state_dict(torch.load(os.path.join(path_dir, 'model.pth')))

        if self.optimizer is not None:
            self.optimizer.load_state_dict(torch.load(os.path.join(path_dir, 'opt.pth')))
            info += ', for opt'
        if self.args.use_sched_ckp and self.sched is not None:
            self.sched.load_state_dict(torch.load(os.path.join(path_dir, 'sched.pth')))
            info += ', for sched'

        with open(os.path.join(path_dir, 'config.json'), 'r') as f:
            config = json.load(f)
        self.logger.info(info)
        return config

    def save_torch_model(self, json_info: dict):
        """
        Args:
            json_info: checkpoint metadata (epoch, metric values) written to config.json
        """
        torch.save(self.model.state_dict(), os.path.join(self.args.out_ckp_path, 'model.pth'))
        torch.save(self.optimizer.state_dict(), os.path.join(self.args.out_ckp_path, 'opt.pth'))
        if self.sched is not None:
            torch.save(self.sched.state_dict(), os.path.join(self.args.out_ckp_path, 'sched.pth'))

        with open(os.path.join(self.args.out_ckp_path, 'config.json'), 'w') as f:
            json.dump(json_info, f)

        self.logger.info(f'- saved model in {self.args.out_ckp_path}')
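
How the pieces fit together — a hedged sketch, not part of the commit; `MyModel` and the data loaders are stand-ins assumed to exist elsewhere:

from utils import get_parser, setup_seed
from metrics import get_metric_dict, Metrics

args = get_parser()
setup_seed(args.seed)
model = MyModel()                                   # hypothetical model
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                             weight_decay=args.weight_decay)
trainer = Trainer(args.device, model, torch.nn.CrossEntropyLoss(),
                  optimizer, args, get_metric_dict(3, list(Metrics)))
trainer.fit(train_loader, valid_loader)             # loaders assumed to exist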
83 changes: 83 additions & 0 deletions utils.py
@@ -0,0 +1,83 @@
import torch
from torch import nn
import os
import argparse
import numpy as np
import random
import json


def str2bool(s):
    # argparse's `type=bool` treats any non-empty string (even 'false') as True,
    # so boolean flags are parsed explicitly instead.
    return str(s).lower() in ('true', '1', 'yes')


def get_parser():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--seed', type=int, default=42, help="random seed")

    # train related
    parser.add_argument('--epochs', type=int, default=30, help='# of epochs')
    parser.add_argument('--base_epoch', default=0, type=int, help='start epoch number, used for ckp')
    parser.add_argument('--batch_size', type=int, default=1024, help='# samples in batch')
    parser.add_argument('--test_ratio', type=float, default=0.2, help='fraction of data held out for testing')
    parser.add_argument('--device', type=str, default='cuda:2', help='cpu, mps, cuda:0, cuda:x')
    parser.add_argument('--use_pretrained_emb', type=str2bool, default=False, help='whether to use pretrained embeddings')

    # optimizer related
    parser.add_argument('--lr', type=float, default=0.1, help='initial learning rate for adam')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay (L2 penalty)')
    parser.add_argument('--use_sched_ckp', type=str2bool, default=False, help='whether to load the scheduler checkpoint')

    # path related
    parser.add_argument('--data_path', type=str, default='./data/feature/', help='data path')
    parser.add_argument('--dict_prop_file', type=str, default='./data/dict_prop.json')
    parser.add_argument('--ckp_path', type=str, default=None, help='checkpoint path to resume from')
    parser.add_argument('--out_ckp_path', type=str, default='./ckps/demo', help='output checkpoint path')
    parser.add_argument('--log_dir', type=str, default='./log/demo', help='tensorboard log directory')

    return parser.parse_args()


def save_torch_model(model_state, opt_state, sched_state, path_dir, cfg: dict):
    """
    Args:
        model_state: must not be None
        opt_state: must not be None
        sched_state: may be None
        path_dir: checkpoint directory. The format `expflag_epoch-{base_epoch}-{cur_epoch}`
            is recommended.
    """
    os.makedirs(path_dir, exist_ok=True)

    torch.save(model_state, os.path.join(path_dir, 'model.pth'))
    torch.save(opt_state, os.path.join(path_dir, 'opt.pth'))
    if sched_state is not None:
        torch.save(sched_state, os.path.join(path_dir, 'sched.pth'))

    with open(os.path.join(path_dir, 'config.json'), 'w') as f:
        json.dump(cfg, f)

    print(f'saved model in {path_dir}')


def load_torch_model(model: nn.Module, opt, sched, path_dir) -> dict:
    """Load state dicts and return the checkpoint config."""
    info = f'[info] loaded from {path_dir}, for model'
    model.load_state_dict(torch.load(os.path.join(path_dir, 'model.pth')))

    if opt is not None:
        opt.load_state_dict(torch.load(os.path.join(path_dir, 'opt.pth')))
        info += ', for opt'
    if sched is not None:
        sched.load_state_dict(torch.load(os.path.join(path_dir, 'sched.pth')))
        info += ', for sched'

    with open(os.path.join(path_dir, 'config.json'), 'r') as f:
        config = json.load(f)
    print(info)
    return config


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
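
A hedged round-trip example for the helpers above; `model` and `optimizer` are assumed to exist, and the directory name and epoch are illustrative only:

setup_seed(42)
save_torch_model(model.state_dict(), optimizer.state_dict(), None,
                 './ckps/demo_epoch-0-5', {'epoch': 5})
cfg = load_torch_model(model, optimizer, None, './ckps/demo_epoch-0-5')
print(cfg['epoch'])  # resume training from cfg['epoch'] + 1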
