diff --git a/classification/engine.py b/classification/engine.py
index 65eb52e..cc31103 100644
--- a/classification/engine.py
+++ b/classification/engine.py
@@ -68,7 +68,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
 
 
 @torch.no_grad()
-def evaluate(data_loader, model, device):
+def evaluate(data_loader, model, device, fp32=False):
     criterion = torch.nn.CrossEntropyLoss()
 
     metric_logger = utils.MetricLogger(delimiter="  ")
@@ -82,7 +82,7 @@ def evaluate(data_loader, model, device):
         target = target.to(device, non_blocking=True)
 
         # compute output
-        with torch.cuda.amp.autocast():
+        with torch.cuda.amp.autocast(enabled=not fp32):
             output = model(images)
             loss = criterion(output, target)
 
diff --git a/classification/main.py b/classification/main.py
index c4eccbf..f50e591 100644
--- a/classification/main.py
+++ b/classification/main.py
@@ -361,7 +361,7 @@ def main(args):
             loss_scaler.load_state_dict(checkpoint['scaler'])
 
     if args.eval:
-        test_stats = evaluate(data_loader_val, model, device)
+        test_stats = evaluate(data_loader_val, model, device, fp32=args.fp32_resume)
         print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
         return
 
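
Below is a minimal sketch (not part of the patch) of what the new `fp32` switch does: `torch.cuda.amp.autocast(enabled=not fp32)` runs the forward pass under mixed precision when `fp32=False` and becomes a no-op (plain float32) when `fp32=True`, which is what `--eval` gets when `args.fp32_resume` is set. The standalone function name and the assumption that the loader yields `(images, target)` batches are illustrative only.

```python
import torch

@torch.no_grad()
def evaluate_fp32_toggle(data_loader, model, device, fp32=False):
    # Hypothetical, simplified version of evaluate() showing the autocast toggle.
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()
    for images, target in data_loader:
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        # enabled=not fp32: mixed precision by default, full FP32 when fp32=True
        with torch.cuda.amp.autocast(enabled=not fp32):
            output = model(images)
            loss = criterion(output, target)
```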