Standardized Logging and Fixed a Bug #71

Open

wants to merge 2 commits into base: dev
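Summary of the diff below: the patch adds a prepare_dirs() helper that configures Python logging to write both to stdout and to log/log.txt, mirrors the existing print() calls with logging.info() so that pretrained-weight loading, checkpoint resume, per-epoch metrics, and the final results all land in the log file, and writes the per-epoch metrics as comma-separated rows under a fixed header so they can be parsed afterwards.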
25 changes: 23 additions & 2 deletions trailmet/algorithms/binarize/BNNBN.py
@@ -42,7 +42,19 @@ def __init__(self, model, dataloaders, **kwargs):
    # numeric defaults: string defaults such as '0' would break the optimizer
    self.weight_decay = self.kwargs.get('weight_decay', 0.0)
    self.learning_rate = self.kwargs.get('learning_rate', 0.001)

def prepare_dirs(self):
    # Create the logging directory once, then configure the root logger to
    # write to stdout and to log/log.txt with a common timestamped format.
    if not os.path.exists('log'):
        print('Creating Logging Directory...')
        os.mkdir('log')
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join('log', 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
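    # Caveat (not in the original patch): calling prepare_dirs() more than
    # once would attach a second FileHandler and duplicate every log line;
    # a guard such as `if not logging.getLogger().handlers:` around the
    # setup above is one hypothetical way to keep the call idempotent.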

def compress_model(self):
    self.prepare_dirs()
    # training below assumes CUDA; exit early when no GPU is available
    if not torch.cuda.is_available():
        sys.exit(1)
    start_t = time.time()
@@ -108,6 +120,7 @@ def compress_model(self):

    if self.pretrained:
        print('* loading pretrained weight {}'.format(self.pretrained))
        logging.info(f'loading pretrained weight {self.pretrained}')
        # `args` is undefined in this scope; use the path stored on self
        pretrain_student = torch.load(self.pretrained)
        if 'state_dict' in pretrain_student.keys():
            pretrain_student = pretrain_student['state_dict']
@@ -122,20 +135,24 @@ def compress_model(self):
    checkpoint_tar = os.path.join(self.save, 'checkpoint.pth.tar')
    if os.path.exists(checkpoint_tar):
        print('loading checkpoint {} ..........'.format(checkpoint_tar))
        logging.info('loading checkpoint {} ..........'.format(checkpoint_tar))
        checkpoint = torch.load(checkpoint_tar)
        start_epoch = checkpoint['epoch']
        best_top1_acc = checkpoint['best_top1_acc']
        self.model_student.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        print("loaded checkpoint {} epoch = {}".format(checkpoint_tar, checkpoint['epoch']))
        logging.info("loaded checkpoint {} epoch = {}".format(checkpoint_tar, checkpoint['epoch']))
    else:
        raise ValueError('no checkpoint for resume')

    if self.loss_type == 'kd':
        if classes_in_teacher != self.num_classes:
            self.validate('teacher', self.val_loader, model_teacher, criterion)

    # CSV-style header: each training epoch appends a matching row below
    logging.info('epoch, train accuracy, train loss, val accuracy, val loss')

    # train the model
    epoch = start_epoch
    while epoch < self.epochs:
@@ -150,7 +167,9 @@ def compress_model(self):
            raise ValueError('unsupported loss_type')

        valid_obj, valid_top1_acc, valid_top5_acc = self.validate(epoch, self.val_loader, self.model_student, criterion)

        logging.info("{}, {}, {}, {}, {}".format(epoch, train_top1_acc, train_obj, valid_top1_acc.item(), valid_obj))

        is_best = False
        if valid_top1_acc > best_top1_acc:
            best_top1_acc = valid_top1_acc
@@ -168,7 +187,9 @@ def compress_model(self):

    training_time = (time.time() - start_t) / 3600
    print('total training time = {} hours'.format(training_time))
    logging.info('total training time = {} hours'.format(training_time))
    print('* best acc = {}'.format(best_top1_acc))
    logging.info('* best acc = {}'.format(best_top1_acc))


def train_kd(self, epoch, train_loader, model_student, model_teacher, criterion, optimizer, scheduler):
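Because the per-epoch metrics are now logged as comma-separated rows under a fixed header, the resulting log/log.txt can be machine-parsed. Below is a minimal sketch of one way to do that. The path log/log.txt, the '%(asctime)s %(message)s' format with datefmt '%m/%d %I:%M:%S %p', and the five-column row all come from the diff above; read_training_curve() and its header-skipping heuristic are hypothetical, not part of this PR.

def read_training_curve(path='log/log.txt'):
    # Metrics lines look like "03/14 10:22:31 AM 5, 61.2, 1.43, 58.9, 1.51":
    # the timestamp prefix (three space-separated tokens) followed by the
    # CSV row logged each epoch in compress_model().
    rows = []
    with open(path) as f:
        for line in f:
            parts = line.strip().split(' ', 3)
            if len(parts) < 4:
                continue
            fields = [p.strip() for p in parts[3].split(',')]
            # keep only the 5-column metric rows, skipping the header line
            if len(fields) == 5 and fields[0] != 'epoch':
                try:
                    rows.append({
                        'epoch': int(fields[0]),
                        'train_acc': float(fields[1]),
                        'train_loss': float(fields[2]),
                        'val_acc': float(fields[3]),
                        'val_loss': float(fields[4]),
                    })
                except ValueError:
                    continue  # some other log message that contains commas
    return rows

curve = read_training_curve()
best = max(curve, key=lambda r: r['val_acc']) if curve else None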