diff --git a/torchbenchmark/models/yolov3/__init__.py b/torchbenchmark/models/yolov3/__init__.py index 3a6bd6e057..9d029d07f6 100644 --- a/torchbenchmark/models/yolov3/__init__.py +++ b/torchbenchmark/models/yolov3/__init__.py @@ -54,7 +54,9 @@ def __init__(self, test, device, batch_size=None, extra_args=[]): self.num_epochs = 1 self.train_num_batch = 1 self.prefetch = True - if test == "train": + if test == "eval" or self.dargs.accuracy: + self.model, self.example_inputs = self.prep_eval() + elif test == "train": train_args = split( f"--data {DATA_DIR}/coco128.data --img 416 --batch {self.batch_size} --nosave --notest \ --epochs {self.num_epochs} --device {self.device_str} --weights '' \ @@ -64,8 +66,6 @@ def __init__(self, test, device, batch_size=None, extra_args=[]): self.training_loop, self.model, self.example_inputs = prepare_training_loop( train_args ) - elif test == "eval": - self.model, self.example_inputs = self.prep_eval() self.amp_context = nullcontext def prep_eval(self):