Skip to content

Commit

Permalink
Fix yolov3 train
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Jun 20, 2024
1 parent b2b4158 commit ebbbd50
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions torchbenchmark/models/yolov3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
self.num_epochs = 1
self.train_num_batch = 1
self.prefetch = True
if test == "train":
if test == "eval" or self.dargs.accuracy:
self.model, self.example_inputs = self.prep_eval()
elif test == "train":
train_args = split(
f"--data {DATA_DIR}/coco128.data --img 416 --batch {self.batch_size} --nosave --notest \
--epochs {self.num_epochs} --device {self.device_str} --weights '' \
Expand All @@ -64,8 +66,6 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
self.training_loop, self.model, self.example_inputs = prepare_training_loop(
train_args
)
elif test == "eval":
self.model, self.example_inputs = self.prep_eval()
self.amp_context = nullcontext

def prep_eval(self):
Expand Down

0 comments on commit ebbbd50

Please sign in to comment.