Skip to content

Commit

Permalink
Simplify the multiple batches design by removing the "train_dynamic" …
Browse files Browse the repository at this point in the history
…test

Summary: Let user control the multiple batch test by simply specifying the `--num-batch` command.

Reviewed By: Bucero

Differential Revision: D50092718

fbshipit-source-id: 15d0e3d85e98982697caf92b2f5242e2cba97711
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Oct 10, 2023
1 parent d4f79d9 commit fd7f14e
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 24 deletions.
2 changes: 1 addition & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def _validate_profile_options(profile_options: str):
parser.add_argument(
"-t",
"--test",
choices=["eval", "train", "train_dynamic"],
choices=["eval", "train"],
default="eval",
help="Which test to run.",
)
Expand Down
2 changes: 1 addition & 1 deletion torchbenchmark/util/extra_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def parse_decoration_args(model: 'torchbenchmark.util.model.BenchmarkModel', ext
parser.add_argument("--accuracy", action="store_true", help="Check accuracy of the model only instead of running the performance test.")
parser.add_argument("--use_cosine_similarity", action='store_true', help="use cosine similarity for correctness check")
parser.add_argument("--quant-engine", choices=QUANT_ENGINES, default='x86', help=f"choose quantization engine for fx_int8 precision from {QUANT_ENGINES}")
parser.add_argument("--num-batch", type=int, help="Number of batches if running the train_dynamic test.")
parser.add_argument("--num-batch", type=int, help="Number of batches if running the multi-batch train test.")
dargs, opt_args = parser.parse_known_args(extra_args)
if not check_precision(model, dargs.precision):
raise NotImplementedError(f"precision value: {dargs.precision}, "
Expand Down
2 changes: 1 addition & 1 deletion torchbenchmark/util/framework/detectron2/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(self, variant, test, device, batch_size=None, extra_args=[]):
self.model = instantiate(cfg.model).to(self.device)

# setup model and return the dataloader
if self.test == "train" or self.test == "train_dynamic":
if self.test == "train":
if hasattr(self, "FCOS_USE_BN") and self.FCOS_USE_BN:
raise NotImplementedError("FCOS train is not supported by upstream detectron2. " \
"See GH Issue: https://github.com/facebookresearch/detectron2/issues/4369.")
Expand Down
4 changes: 1 addition & 3 deletions torchbenchmark/util/framework/huggingface/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ class HuggingFaceModel(BenchmarkModel):
HF_MODEL = True
# Default eval precision on CUDA device is fp16(half mode)
DEFAULT_EVAL_CUDA_PRECISION = "fp16"
# When running the train_dynamic test, run 100 batches of input
DEFAULT_NUM_BATCH = 10

# If you suffix a model with '_generate', we will instead wrap the
# unsuffixed model with GenerationWrapper which will make it do
Expand All @@ -74,7 +72,7 @@ def __init__(self, name, test, device, batch_size=None, extra_args=[]):
self.is_generate = False
self.unqual_name = name
name = self.unqual_name # we don't want to refer to the qualified name anymore
if test == "train" or test == "train_dynamic":
if test == "train":
self.max_length = class_models[name][0]
elif test == "eval":
self.max_length = class_models[name][1]
Expand Down
4 changes: 1 addition & 3 deletions torchbenchmark/util/framework/timm/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ class TimmModel(BenchmarkModel):
DEFAULT_EVAL_BSIZE = None
# Default eval precision on CUDA device is fp16
DEFAULT_EVAL_CUDA_PRECISION = "fp16"
# When running the train_dynamic test, run 100 batches of input
DEFAULT_NUM_BATCH = 10

def __init__(self, model_name, test, device, batch_size=None, extra_args=[]):
super().__init__(test=test, device=device, batch_size=batch_size, extra_args=extra_args)
Expand All @@ -29,7 +27,7 @@ def __init__(self, model_name, test, device, batch_size=None, extra_args=[]):
self.model.to(
device=self.device
)
if test == "train" or test == "train_dynamic":
if test == "train":
self.model.train()
elif test == "eval":
self.model.eval()
Expand Down
4 changes: 1 addition & 3 deletions torchbenchmark/util/framework/vision/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ class TorchVisionModel(BenchmarkModel):
DEFAULT_EVAL_CUDA_PRECISION = "fp16"
# Whether to skip the opt zero grad
SKIP_ZERO_GRAD = False
# When running the train_dynamic test, run 100 batches of input
DEFAULT_NUM_BATCH = 10

def __init__(self, model_name, test, device, batch_size=None, weights=None, extra_args=[]):
super().__init__(test=test, device=device, batch_size=batch_size, extra_args=extra_args)
Expand All @@ -29,7 +27,7 @@ def __init__(self, model_name, test, device, batch_size=None, weights=None, extr
else:
self.model = getattr(models, model_name)(weights=weights).to(self.device)
self.example_inputs = (torch.randn((self.batch_size, 3, 224, 224)).to(self.device), )
if test == "train" or test == "train_dynamic":
if test == "train":
# compute loss
with torch.no_grad():
self.example_outputs = (torch.rand_like(self.model(*self.example_inputs)), )
Expand Down
20 changes: 8 additions & 12 deletions torchbenchmark/util/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,14 @@ def __init__(self, test: str, device: str, batch_size: Optional[int]=None, extra
self.metadata = self._load_metadata()
self.test = test
# sanity checks of the options
assert self.test == "train" or self.test == "eval" or self.test == "train_dynamic", \
f"Test must be 'train', 'train_dynamic', or 'eval', but provided {self.test}."
assert self.test == "train_dynamic" and is_staged_train_test(self) or (not self.test == "train_dynamic"), \
f"Dynamic shapes must be implemented with staged train test."
assert self.test == "train" or self.test == "eval", \
f"Test must be 'train' or 'eval', but provided {self.test}."
self.device = device
self.extra_args = extra_args
self.opt = None
self._skip_by_device_name()
# contexts to run in the test function
if self.test == "train" or self.test == "train_dynamic":
if self.test == "train":
# In train test, there are run contexts that should only be applied for forward/backward/optimizer stage
# For example, amp only applies for the forward stage
self.forward_contexts = []
Expand Down Expand Up @@ -164,12 +162,11 @@ def _skip_by_device_name(self):
raise NotImplementedError(f"The current device {current_device_name} is skipped by its `{self.name}/metadata.yaml`.")

def _determine_dynamic_num_batches(self, user_specified_num_batches: Optional[int]) -> int:
if self.test == "train" or self.test == "eval":
return 1
if user_specified_num_batches:
if user_specified_num_batches and not user_specified_num_batches == 1:
assert self.test == "train", "Only train test support multiple batches at this moment."
return user_specified_num_batches
assert hasattr(self, 'DEFAULT_NUM_BATCH'), f"We expect all models with dynamic shapes specify field `DEFAULT_NUM_BATCHES`."
return self.DEFAULT_NUM_BATCH
# If user does not specify num_batch, run a single batch by default
return 1

def _get_batch_size_from_metadata(self) -> Optional[str]:
if self.device != "cuda":
Expand Down Expand Up @@ -310,8 +307,7 @@ def _invoke_staged_train_test(self, num_batch: int) -> None:
def invoke(self) -> Optional[Tuple[torch.Tensor]]:
if self.test == "train" and is_staged_train_test(self):
return self._invoke_staged_train_test(num_batch=self.num_batch)
if self.test == "train_dynamic":
return self._invoke_staged_train_test(num_batch=self.num_batch)
assert self.num_batch == 1, "Only staged_train_test supports multiple-batch testing at this time."
out = None
with nested(*self.run_contexts):
if self.test == "train":
Expand Down

0 comments on commit fd7f14e

Please sign in to comment.