Echo pretrain #5

Open · wants to merge 5 commits into main
3 changes: 2 additions & 1 deletion .gitignore
@@ -10,4 +10,5 @@ summary*
run*
*.pth
*.png
*.sh
tags
3 changes: 3 additions & 0 deletions constants.py
@@ -0,0 +1,3 @@
import os

DATASET_ERROR_VERBOSITY = int(os.getenv("DATASET_ERROR_VERBOSITY", "0"))
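
constants.py only defines the DATASET_ERROR_VERBOSITY environment flag; the dataset code that reads it is not shown in this diff. Below is a minimal sketch, assuming the flag gates how loudly a dataset reports unreadable samples; the SafeVideoDataset wrapper and its fallback behaviour are illustrative, not part of the PR.

# Illustrative only, not part of this PR: one way a dataset wrapper could
# consume DATASET_ERROR_VERBOSITY to control error reporting for bad samples.
import traceback
from torch.utils.data import Dataset

from constants import DATASET_ERROR_VERBOSITY


class SafeVideoDataset(Dataset):
    """Hypothetical wrapper that skips unreadable samples instead of crashing."""

    def __init__(self, base_dataset):
        self.base = base_dataset

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        try:
            return self.base[idx]
        except Exception as exc:
            if DATASET_ERROR_VERBOSITY >= 2:
                traceback.print_exc()              # full stack trace
            elif DATASET_ERROR_VERBOSITY == 1:
                print(f"bad sample {idx}: {exc}")  # one-line summary
            # verbosity 0 stays silent; either way, fall back to a neighbouring sample
            return self.base[(idx + 1) % len(self.base)]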
46 changes: 39 additions & 7 deletions engine_pretrain.py
@@ -13,17 +13,23 @@
import sys
from typing import Iterable

import wandb
import torch
from timeit import default_timer

import utils.misc as misc
import utils.lr_sched as lr_sched


def train_one_epoch(model: torch.nn.Module,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int, loss_scaler,
log_writer=None,
args=None):
def train_one_epoch(
model: torch.nn.Module,
data_loader: Iterable,
optimizer: torch.optim.Optimizer,
device: torch.device,
epoch: int, loss_scaler,
log_writer=None,
args=None,
):
model.train(True)
metric_logger = misc.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
@@ -37,14 +43,18 @@ def train_one_epoch(model: torch.nn.Module,
if log_writer is not None:
print('log_dir: {}'.format(log_writer.log_dir))

for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
epoch_start = data_start = default_timer()
for data_iter_step, samples in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
data_time = default_timer() - data_start

# we use a per iteration (instead of per epoch) lr scheduler
if data_iter_step % accum_iter == 0:
lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

samples = samples.to(device, non_blocking=True)

fwd_bwd_start = default_timer()

with torch.cuda.amp.autocast():
loss, _, _ = model(samples, mask_ratio=args.mask_ratio)

@@ -57,9 +67,28 @@ def train_one_epoch(model: torch.nn.Module,
loss /= accum_iter
loss_scaler(loss, optimizer, parameters=model.parameters(),
update_grad=(data_iter_step + 1) % accum_iter == 0)
fwd_bwd_time = default_timer() - fwd_bwd_start

if (data_iter_step + 1) % accum_iter == 0:
optimizer.zero_grad()

num_frames = samples.size(0) * samples.size(2)
fps = num_frames / (fwd_bwd_time + data_time)
if data_iter_step % 10 == 0 and misc.is_main_process():
print(
f"step {data_iter_step} | loss {loss_value:.4f} | fwdbwd_t {fwd_bwd_time:.4f} | "
f"data_t {data_time:.4f} | fps {fps:.4f}"
)
wandb.log(
{
'loss': loss_value,
'fwdbwd_t': fwd_bwd_time,
'data_t': data_time,
'fps': fps,
},
step=data_iter_step,
)

torch.cuda.synchronize()

metric_logger.update(loss=loss_value)
@@ -76,8 +105,11 @@ def train_one_epoch(model: torch.nn.Module,
log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
log_writer.add_scalar('lr', lr, epoch_1000x)

data_start = default_timer()

epoch_time = default_timer() - epoch_start
print(f"Epoch time: {epoch_time}")
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
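
The timing code added to train_one_epoch reports throughput as frames processed per second of wall-clock time (data loading plus forward/backward). A standalone restatement of that bookkeeping, assuming clips are laid out as (B, C, T, H, W) so that samples.size(0) * samples.size(2) is the frame count per step; the tensor shape and sleep calls below are stand-ins, not taken from the PR.

# Sketch of the fps bookkeeping above, with stand-in work instead of a real model.
import time
from timeit import default_timer

import torch

samples = torch.zeros(8, 3, 32, 112, 112)    # B=8 clips of T=32 frames each

data_start = default_timer()
time.sleep(0.05)                             # stand-in for data loading
data_time = default_timer() - data_start

fwd_bwd_start = default_timer()
time.sleep(0.20)                             # stand-in for forward + backward + step
fwd_bwd_time = default_timer() - fwd_bwd_start

num_frames = samples.size(0) * samples.size(2)   # 8 * 32 = 256 frames this step
fps = num_frames / (fwd_bwd_time + data_time)    # ~256 / 0.25 ≈ 1024 frames/s
print(f"fps {fps:.1f}")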
69 changes: 39 additions & 30 deletions main_pretrain.py
@@ -25,13 +25,12 @@

import utils.misc as misc
from utils.misc import NativeScalerWithGradNormCount as NativeScaler
from utils.datasets import VideoFrameDataset
from utils.video_frame_dataset import VideoFrameDataset

import models_mae3d
from engine_pretrain import train_one_epoch
import wandb

wandb.init(project="SuTr", entity="cyrilzakka")

def get_args_parser():
parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
@@ -61,14 +60,18 @@ def get_args_parser():
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0')
parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
parser.add_argument('--warmup_epochs', type=int, default=0, metavar='N',
help='epochs to warmup LR')

# Dataset parameters
parser.add_argument('--data_path', default='/home/cyril/Datasets/MAE/', type=str,
help='dataset path')
parser.add_argument('--num_segments', default=4, type=int)
parser.add_argument('--frames_per_segment', default=4, type=int)
parser.add_argument(
'--data_path',
default='/scratch/groups/willhies/echo/echoai/combined.csv',
type=str,
help='dataset path',
)
parser.add_argument('--num_frames', default=32, type=int)
parser.add_argument('--stride', default=8, type=int)
parser.add_argument('--output_dir', default='./output_dir',
help='path where to save, empty for no saving')
parser.add_argument('--device', default='cuda',
@@ -110,17 +113,12 @@ def main(args):

cudnn.benchmark = True

# simple augmentation
transform_train = transforms.Compose([
transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=Image.BICUBIC), # 3 is bicubic
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
dataset_train = VideoFrameDataset(root_path=os.path.join(args.data_path, 'train'), annotationfile_path=os.path.join(args.data_path, 'ledger.csv'),
num_segments= args.num_segments,
frames_per_segment = args.frames_per_segment,
transform = transform_train,
test_mode = False)
dataset_train = VideoFrameDataset(
ledger_path=args.data_path,
num_frames=args.num_frames,
stride=args.stride,
do_augmentation=True,
)
print(dataset_train)

if True: # args.distributed:
@@ -140,16 +138,19 @@
pin_memory=args.pin_mem,
drop_last=True,
)

# define the model
model = models_mae3d.__dict__[args.model](num_frames=int(args.num_segments*args.frames_per_segment), norm_pix_loss=args.norm_pix_loss)
model = models_mae3d.__dict__[args.model](
num_frames=args.num_frames,
norm_pix_loss=args.norm_pix_loss,
)
model.to(device)

model_without_ddp = model
print("Model = %s" % str(model_without_ddp))
# print("Model = %s" % str(model_without_ddp))

eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

if args.lr is None: # only base_lr is specified
args.lr = args.blr * eff_batch_size / 256

@@ -162,15 +163,22 @@
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
model_without_ddp = model.module

# following timm: set wd as 0 for bias and norm layers
param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, args.weight_decay)
optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
print(optimizer)
loss_scaler = NativeScaler()

misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
misc.load_model(
args=args,
model_without_ddp=model_without_ddp,
optimizer=optimizer,
loss_scaler=loss_scaler,
)

wandb.init(project="SuTr", entity="akashc")
wandb.config.update(args)
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
@@ -182,14 +190,16 @@
log_writer=None,
args=args
)
if args.output_dir and (epoch % 20 == 0 or epoch + 1 == args.epochs):
if args.output_dir:
misc.save_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch)

log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
'epoch': epoch,}

log_stats = {
**{f'train_{k}': v for k, v in train_stats.items()},
'epoch': epoch,
}

if misc.is_main_process():
wandb.log(log_stats)

Expand All @@ -201,7 +211,6 @@ def main(args):
if __name__ == '__main__':
args = get_args_parser()
args = args.parse_args()
wandb.config.update(args)

if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
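
For the learning-rate handling kept in main_pretrain.py: when only --blr is given, the absolute rate follows the linear scaling rule lr = blr * eff_batch_size / 256, where eff_batch_size folds in gradient accumulation and the number of distributed workers. A quick worked example with made-up values:

# Worked example of the scaling rule; the values are illustrative, not PR defaults.
batch_size = 2       # per-GPU batch size (--batch_size)
accum_iter = 8       # gradient accumulation steps (--accum_iter)
world_size = 4       # number of distributed processes
blr = 1.5e-4         # base learning rate (--blr)

eff_batch_size = batch_size * accum_iter * world_size   # 2 * 8 * 4 = 64
lr = blr * eff_batch_size / 256                         # 1.5e-4 * 64 / 256 = 3.75e-05
print(eff_batch_size, lr)                               # 64 3.75e-05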
2 changes: 1 addition & 1 deletion utils/lr_sched.py
@@ -13,7 +13,7 @@
def adjust_learning_rate(optimizer, epoch, args):
"""Decay the learning rate with half-cycle cosine after warmup"""
if epoch < args.warmup_epochs:
lr = args.lr * epoch / args.warmup_epochs
else:
lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
(1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
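
The schedule in adjust_learning_rate warms up linearly to args.lr over warmup_epochs (which this PR defaults to 0) and then decays with a half-cycle cosine toward min_lr. A self-contained restatement of the same formula with placeholder hyperparameters, to make the shape easy to check:

# Standalone restatement of the schedule above; hyperparameters are placeholders.
import math

lr, min_lr = 1e-3, 0.0
epochs, warmup_epochs = 100, 5

def scheduled_lr(epoch):
    if epoch < warmup_epochs:
        return lr * epoch / warmup_epochs   # linear warmup
    progress = (epoch - warmup_epochs) / (epochs - warmup_epochs)
    return min_lr + (lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))

for e in (0, 5, 50, 100):
    print(e, f"{scheduled_lr(e):.2e}")      # 0.00e+00, 1.00e-03, ~5.41e-04, 0.00e+00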
6 changes: 4 additions & 2 deletions utils/misc.py
@@ -18,7 +18,7 @@

import torch
import torch.distributed as dist
from torch._six import inf
from torch import inf


class SmoothedValue(object):
@@ -230,6 +230,8 @@ def init_distributed_mode(args):
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
else:
print('Not using distributed mode')
setup_for_distributed(is_master=True) # hack
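
The SLURM branch now pins MASTER_ADDR and MASTER_PORT so the env:// rendezvous used by torch.distributed can resolve; note the hard-coded 127.0.0.1 assumes all ranks share a single node. A simplified sketch of how those variables feed init_process_group, not copied from utils/misc.py (the real helper handles more cases):

# Simplified sketch of how the env vars set above feed the env:// rendezvous.
# Not copied from utils/misc.py; using SLURM_NTASKS as the world size is an assumption.
import os

import torch
import torch.distributed as dist

rank = int(os.environ['SLURM_PROCID'])
world_size = int(os.environ['SLURM_NTASKS'])
gpu = rank % torch.cuda.device_count()

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')   # single-node assumption, as in the diff
os.environ.setdefault('MASTER_PORT', '29500')

torch.cuda.set_device(gpu)
dist.init_process_group(backend='nccl', init_method='env://',
                        world_size=world_size, rank=rank)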
@@ -337,4 +339,4 @@ def all_reduce_mean(x):
x_reduce /= world_size
return x_reduce.item()
else:
return x