Commit efcd78a (1 parent: a977dda), showing 5 changed files with 347 additions and 0 deletions.
@@ -0,0 +1,24 @@
import logging


def setup_logger(name, log_file, level=logging.INFO):
    """Set up as many independent loggers as you need."""

    # Create a logger
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Handler that writes to the log file
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s | %(message)s'))

    # Handler that writes to the console
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s | %(message)s'))

    # Attach both handlers to the logger
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    return logger
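A quick usage sketch of this helper (the module name logger.py is implied by the `from logger import setup_logger` in the trainer below; the logger name and path here are made-up examples):

from logger import setup_logger

trainer_logger = setup_logger('trainer', './log/trainer.log')  # hypothetical log path
trainer_logger.info('training started')  # written to both the file and the console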
@@ -0,0 +1,41 @@
from torchmetrics import AUROC, Accuracy, Precision, Recall, F1Score
from enum import Enum


class Metrics(Enum):
    acc = 0
    auc = 1
    precision = 2
    recall = 3
    f1 = 4


def get_metric_dict(num_classes, metrics):
    """
    Args:
        num_classes: number of classes
        metrics: an iterable of `Metrics` members, e.g. `[Metrics.acc, Metrics.auc, ...]`
    Return:
        `{Metrics.acc: Accuracy(num_classes=num_classes, average='macro'), ...}`
    """
    task_type = 'binary' if num_classes == 2 else 'multiclass'
    metrics_dict = {}

    for key in metrics:
        if key == Metrics.acc:
            metrics_dict[key] = Accuracy(task=task_type, average='macro', num_classes=num_classes)
        elif key == Metrics.auc:
            metrics_dict[key] = AUROC(num_classes=num_classes, task=task_type)
        elif key == Metrics.precision:
            metrics_dict[key] = Precision(num_classes=num_classes, average='macro', task=task_type)
        elif key == Metrics.recall:
            metrics_dict[key] = Recall(num_classes=num_classes, average='macro', task=task_type)
        elif key == Metrics.f1:
            metrics_dict[key] = F1Score(num_classes=num_classes, average='macro', task=task_type)
        else:
            raise ValueError(f"unsupported metric name {key.name}")

    return metrics_dict


metrics = get_metric_dict(3, [Metrics.acc, Metrics.auc, Metrics.precision, Metrics.recall, Metrics.f1])
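Each value in the returned dict is a stateful torchmetrics object that can be called on a batch of predictions and targets; a minimal sketch using the `metrics` dict built above, with made-up tensors:

import torch

preds = torch.randn(8, 3).softmax(dim=-1)   # made-up scores for 8 samples, 3 classes
targets = torch.randint(0, 3, (8,))         # made-up integer labels
acc = metrics[Metrics.acc](preds, targets)  # returns a 0-dim tensor
print(acc.item())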
@@ -0,0 +1,14 @@
python demo.py \
--seed 42 \
--epochs 5 \
--batch_size 128 \
--device 'cuda:2' \
--lr 1e-3 \
--weight_decay 0.0 \
--data_path './data/feature/' \
--ckp_path './ckps/resume1' \
--use_sched_ckp false \
--out_ckp_path './ckps/resume2' \
--log_dir './log/resume1'

# --ckp_path './ckps/resume' \
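One caveat with this invocation: `--use_sched_ckp false` does not disable the flag. The parser further below declares it with `type=bool`, and Python's `bool('false')` is `True`, so any non-empty string enables it. A common workaround (a sketch, not part of this commit) is a small converter passed as `type=`:

import argparse

def str2bool(v):
    # Map common textual spellings onto real booleans.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError(f'boolean value expected, got {v!r}')

# usage: parser.add_argument('--use_sched_ckp', type=str2bool, default=False)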
@@ -0,0 +1,185 @@
import os
import json
from torch.utils.tensorboard.writer import SummaryWriter
import torch
from logger import setup_logger
from tqdm import tqdm
import numpy as np


class Trainer:
    def __init__(
        self,
        device,
        model,
        criterion,
        optimizer,
        args,
        metrics_dict,
        sched=None,
        use_writer=True
    ):
        self.device = device
        self.model = model.to(device)
        self.optimizer = optimizer
        self.sched = sched
        self.criterion = criterion
        self.args = args
        self.base_epoch = args.base_epoch
        self._pre_config()

        # Resuming from a checkpoint overrides the starting epoch.
        if args.ckp_path is not None:
            json_info = self.load_torch_model()
            self.base_epoch = json_info['epoch'] + 1

        if use_writer:
            self.writer = SummaryWriter(log_dir=args.log_dir)
        else:
            self.writer = None
        self.metrics_dict = metrics_dict

    def _pre_config(self):
        # Output checkpoint directory: create it, dump the run args, and set up logging.
        os.makedirs(self.args.out_ckp_path, exist_ok=True)
        with open(os.path.join(self.args.out_ckp_path, 'args.json'), 'w') as f:
            json.dump(self.args.__dict__, f)
        self.logger = setup_logger('trainer', os.path.join(self.args.out_ckp_path, 'trainer.log'))

    def fit(self, train_loader, valid_loader, test_loader=None):
        best_auc = 0.0
        for epoch in range(self.base_epoch, self.base_epoch + self.args.epochs):
            losses = []
            preds = []
            targets = []

            progress_bar = tqdm(train_loader, leave=False)
            for x, y in progress_bar:
                x = x.to(self.device)
                y = y.to(self.device)

                sim_logits = self.model(x)
                loss = self.criterion(sim_logits, y)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                losses.append(loss.item())

                progress_bar.set_postfix_str(f'loss={loss.item()}')

                preds.append(sim_logits.detach())
                targets.append(y.detach())

            preds = torch.concat(preds, dim=0)
            targets = torch.concat(targets, dim=0)
            loss_val = np.mean(losses)
            train_metrics_val = self._compute_metrics(preds, targets)
            train_metrics_val.update({'loss': loss_val})

            if self.sched is not None:
                self.sched.step()

            # Keep the checkpoint with the best validation AUC.
            valid_metrics_val = self.valid(valid_loader)
            if valid_metrics_val['auc'] > best_auc:
                best_auc = valid_metrics_val['auc']
                self.save_torch_model(
                    {'epoch': epoch, 'train_auc': train_metrics_val['auc'], 'valid_auc': valid_metrics_val['auc']}
                )

            self._write_tensorboard(train_metrics_val, valid_metrics_val, self.optimizer.param_groups[0]['lr'], epoch)
            self._print_metrics_val(epoch, train_metrics_val, valid_metrics_val)

    def _print_metrics_val(self, epoch, train_met_val, valid_met_val, test_met_val=None):
        info = f"[epoch: {epoch}] "
        for k, v in train_met_val.items():
            info += f"{k}: {v:.6f}, "
        for k, v in valid_met_val.items():
            info += f"{k}: {v:.6f}, "
        if test_met_val is not None:
            for k, v in test_met_val.items():
                info += f"{k}: {v:.6f}, "
        self.logger.info(info)

    def _write_tensorboard(self, train_met_val, valid_met_val, lr, step, test_met_val=None):
        """
        Args:
            train_met_val / valid_met_val / test_met_val: dicts of `{metric_name: value}`
            lr: current learning rate
            step: global step (here, the epoch) used as the x-axis
        """
        def _helper(cate, met_val):
            # cate: 'train' / 'valid' / 'test'
            for k, v in met_val.items():
                self.writer.add_scalar(f'{k}/{cate}', v, step)
        _helper('train', train_met_val)
        _helper('valid', valid_met_val)
        if test_met_val is not None:
            _helper('test', test_met_val)
        self.writer.add_scalar('lr', lr, step)

    @torch.no_grad()
    def _compute_metrics(self, preds: torch.Tensor, targets: torch.Tensor) -> dict:
        """
        Return:
            `{acc: 0.99999, auc: 0.9999, ...}`
        """
        metrics_res = {}
        for k, fn in self.metrics_dict.items():
            # The torchmetrics objects live on the CPU, so move the tensors there first.
            metrics_res[k.name] = fn(preds.cpu(), targets.cpu()).item()
        return metrics_res

    @torch.no_grad()
    def valid(self, valid_loader):
        self.model.eval()

        losses = []
        targets = []
        preds = []
        progress_bar = tqdm(valid_loader, leave=False)
        for x, y in progress_bar:
            x = x.to(self.device)
            y = y.to(self.device)
            sim_logits = self.model(x)
            loss = self.criterion(sim_logits, y).item()
            losses.append(loss)
            targets.append(y.detach())
            preds.append(sim_logits.detach())
            progress_bar.set_postfix_str(f'loss={loss}')
        self.model.train()

        preds = torch.concat(preds, dim=0)
        targets = torch.concat(targets, dim=0)
        eval_metrics_val = self._compute_metrics(preds, targets)
        eval_metrics_val.update({'loss': np.mean(losses)})
        return eval_metrics_val

    def load_torch_model(self) -> dict:
        """Load model/optimizer/scheduler state and return the checkpoint info."""
        path_dir = self.args.ckp_path
        info = f'- loaded from {path_dir}, for model'
        self.model.load_state_dict(torch.load(os.path.join(path_dir, 'model.pth')))

        if self.optimizer is not None:
            self.optimizer.load_state_dict(torch.load(os.path.join(path_dir, 'opt.pth')))
            info += ', for opt'
        if self.args.use_sched_ckp and self.sched is not None:
            self.sched.load_state_dict(torch.load(os.path.join(path_dir, 'sched.pth')))
            info += ', for sched'

        with open(os.path.join(path_dir, 'config.json'), 'r') as f:
            config = json.load(f)
        self.logger.info(info)
        return config

    def save_torch_model(self, json_info: dict):
        """
        Args:
            json_info: lightweight metadata (epoch, metric values) written alongside the weights
        """
        torch.save(self.model.state_dict(), os.path.join(self.args.out_ckp_path, 'model.pth'))
        torch.save(self.optimizer.state_dict(), os.path.join(self.args.out_ckp_path, 'opt.pth'))
        if self.sched is not None:
            torch.save(self.sched.state_dict(), os.path.join(self.args.out_ckp_path, 'sched.pth'))

        with open(os.path.join(self.args.out_ckp_path, 'config.json'), 'w') as f:
            json.dump(json_info, f)

        self.logger.info(f'- saved model in {self.args.out_ckp_path}')
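To show how the pieces fit together, a minimal wiring sketch. The model and data are hypothetical stand-ins, and the module names metrics/trainer/utils are assumptions (the commit does not show its file names); get_parser and setup_seed come from the utilities below:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from metrics import Metrics, get_metric_dict  # assumed module names for the files
from trainer import Trainer                   # in this commit
from utils import get_parser, setup_seed

args = get_parser()
setup_seed(args.seed)

# Hypothetical 3-class model and synthetic data, just to exercise the loop.
model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 3))
X, y = torch.randn(256, 128), torch.randint(0, 3, (256,))
train_loader = DataLoader(TensorDataset(X, y), batch_size=args.batch_size)
valid_loader = DataLoader(TensorDataset(X, y), batch_size=args.batch_size)

trainer = Trainer(
    device=args.device,  # default is 'cuda:2'; pass --device cpu without a GPU
    model=model,
    criterion=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay),
    args=args,
    metrics_dict=get_metric_dict(3, [Metrics.acc, Metrics.auc]),
)
trainer.fit(train_loader, valid_loader)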
@@ -0,0 +1,83 @@
import torch
from torch import nn
import os
import argparse
import numpy as np
import random
import json


def get_parser():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--seed', type=int, default=42, help="random seed")

    # train related
    parser.add_argument('--epochs', type=int, default=30, help='number of training epochs')
    parser.add_argument('--base_epoch', default=0, type=int, help='start epoch number, used for ckp')
    parser.add_argument('--batch_size', type=int, default=1024, help='# samples in batch')
    parser.add_argument('--test_ratio', type=float, default=0.2, help='fraction of data held out for testing')
    parser.add_argument('--device', type=str, default='cuda:2', help='cpu, mps, cuda:0, cuda:x')
    parser.add_argument('--use_pretrained_emb', type=bool, default=False, help='')

    # optimizer related
    parser.add_argument('--lr', type=float, default=0.1, help='initial learning rate for adam')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay (L2 penalty)')
    # NOTE: argparse's type=bool treats any non-empty string (including 'false') as True;
    # see the str2bool sketch after the run script above.
    parser.add_argument('--use_sched_ckp', type=bool, default=False, help='whether to restore the scheduler state from the checkpoint')

    # path related
    parser.add_argument('--data_path', type=str, default='./data/feature/', help='data path')
    parser.add_argument('--dict_prop_file', type=str, default='./data/dict_prop.json')
    parser.add_argument('--ckp_path', type=str, default=None, help='checkpoint path to resume from')
    parser.add_argument('--out_ckp_path', type=str, default='./ckps/demo', help='output checkpoint path')
    parser.add_argument('--log_dir', type=str, default='./log/demo', help='tensorboard log directory')

    return parser.parse_args()


def save_torch_model(model_state, opt_state, sched_state, path_dir, cfg: dict):
    """
    Args:
        model_state: must not be None
        opt_state: must not be None
        sched_state: may be None
        path_dir: directory of the checkpoint. The format `expflag_epoch-{base_epoch}-{cur_epoch}`
            is recommended.
    """
    os.makedirs(path_dir, exist_ok=True)

    torch.save(model_state, os.path.join(path_dir, 'model.pth'))
    torch.save(opt_state, os.path.join(path_dir, 'opt.pth'))
    if sched_state is not None:
        torch.save(sched_state, os.path.join(path_dir, 'sched.pth'))

    with open(os.path.join(path_dir, 'config.json'), 'w') as f:
        json.dump(cfg, f)

    print(f'saved model in {path_dir}')


def load_torch_model(model: nn.Module, opt, sched, path_dir) -> dict:
    """Load state and return the checkpoint config."""
    info = f'[info] loaded from {path_dir}, for model'
    model.load_state_dict(torch.load(os.path.join(path_dir, 'model.pth')))

    if opt is not None:
        opt.load_state_dict(torch.load(os.path.join(path_dir, 'opt.pth')))
        info += ', for opt'
    if sched is not None:
        sched.load_state_dict(torch.load(os.path.join(path_dir, 'sched.pth')))
        info += ', for sched'

    with open(os.path.join(path_dir, 'config.json'), 'r') as f:
        config = json.load(f)
    print(info)
    return config


def setup_seed(seed):
    # Seed every RNG that the training loop touches.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
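And a short sketch of the checkpoint round trip these helpers implement (the model here is a made-up example; the path is the default from the parser above):

import torch
from torch import nn

setup_seed(42)
model = nn.Linear(8, 2)                      # made-up model
opt = torch.optim.Adam(model.parameters())

save_torch_model(model.state_dict(), opt.state_dict(), None, './ckps/demo', {'epoch': 4})
config = load_torch_model(model, opt, None, './ckps/demo')
assert config['epoch'] == 4                  # the Trainer resumes from config['epoch'] + 1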