Move torchbench model configuration into a YAML file. (#120299)
Summary:
This PR moves the remaining aspects of torchbench's model configuration (e.g. batch
sizes and tolerance requirements) into a new YAML file: `torchbench.yaml`. It also
merges the recently added `torchbench_skip_models.yaml` file into it, under the
`skip` key.

This is an effort to let external consumers easily replicate the performance and
coverage results shown on the PyTorch HUD.

X-link: pytorch/pytorch#120299
Approved by: https://github.com/jansel

Reviewed By: jeanschmidt

Differential Revision: D54123721

fbshipit-source-id: c6e69269775fa8a70021fe13313293a527c6b3e1
ysiraichi authored and facebook-github-bot committed Feb 24, 2024
1 parent 4386604 commit a099658
1 changed file with 64 additions and 183 deletions: userbenchmark/dynamo/dynamobench/torchbench.py
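
For orientation, here is a minimal sketch of how the new `torchbench.yaml` appears to be organized. It is not part of the commit: the key names are inferred from the accessors in the diff below, the example entries are copied from the Python constants being removed, and the real file may group or populate things differently (the `skip` section in particular is only indicated, since its contents come from the former `torchbench_skip_models.yaml`).

    # Illustrative layout only; not the complete file.
    batch_size:
      training:              # was USE_SMALL_BATCH_SIZE
        demucs: 4
        dlrm: 1024
        timm_efficientdet: 1
      inference:             # was INFERENCE_SMALL_BATCH_SIZE
        timm_efficientdet: 32

    dont_change_batch_size:  # was DONT_CHANGE_BATCH_SIZE
      - demucs
      - pytorch_struct

    tolerance:
      higher_fp16:           # was REQUIRE_HIGHER_FP16_TOLERANCE
        - doctr_reco_predictor
        - drq
      higher_bf16:           # was REQUIRE_HIGHER_BF16_TOLERANCE
        - doctr_reco_predictor
      cosine: []             # was REQUIRE_COSINE_TOLERACE (empty)
      higher:                # was REQUIRE_HIGHER_TOLERANCE
        - alexnet
        - densenet121
      even_higher:           # was REQUIRE_EVEN_HIGHER_TOLERANCE
        - soft_actor_critic
        - tacotron2

    accuracy:
      max_batch_size:        # was MAX_BATCH_SIZE_FOR_ACCURACY_CHECK
        hf_GPT2: 2
        pytorch_unet: 2
      skip:
        large_models: []             # entries from the former skip-models file
        eager_not_deterministic: []

    dtype:
      force_amp_for_fp16_bf16_models:  # was FORCE_AMP_FOR_FP16_BF16_MODELS
        - DALLE2_pytorch
        - Super_SloMo
      force_fp16_for_bf16_models:      # was FORCE_FP16_FOR_BF16_MODELS
        - vision_maskrcnn

    only_training:           # was ONLY_TRAINING_MODE
      - tts_angular
      - tacotron2

    slow:                    # was SLOW_BENCHMARKS
      - BERT_pytorch
    very_slow:               # was VERY_SLOW_BENCHMARKS
      - hf_BigBird

    non_deterministic:       # was NONDETERMINISTIC
      - mobilenet_v3_large

    trt_not_yet_working:     # was TRT_NOT_YET_WORKING
      - alexnet

    canary_models:           # was CANARY_MODELS
      - torchrec_dlrm
      - clip

    skip:                    # merged from torchbench_skip_models.yaml
      all: []
      device:
        cpu: []
        cuda: []
      test:
        training: []
      multiprocess: []
      control_flow: []

Note that `load_yaml_file()` in the diff below converts every list in this file into a Python set (nested lists are flattened first), so the runner's existing `name in ...` membership checks keep working; mappings such as the batch-size tables stay dictionaries.
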
@@ -48,167 +48,26 @@ def setup_torchbench_cwd():
     return original_dir


-# Some models have large dataset that doesn't fit in memory. Lower the batch
-# size to test the accuracy.
-USE_SMALL_BATCH_SIZE = {
-    "demucs": 4,
-    "dlrm": 1024,
-    "densenet121": 4,
-    "hf_Reformer": 4,
-    "hf_T5_base": 4,
-    "timm_efficientdet": 1,
-    "llama_v2_7b_16h": 1,
-    "yolov3": 8,  # reduced from 16 due to cudagraphs OOM in TorchInductor dashboard
-}
-
-INFERENCE_SMALL_BATCH_SIZE = {
-    "timm_efficientdet": 32,
-}
-
-DETECTRON2_MODELS = {
-    "detectron2_fasterrcnn_r_101_c4",
-    "detectron2_fasterrcnn_r_101_dc5",
-    "detectron2_fasterrcnn_r_101_fpn",
-    "detectron2_fasterrcnn_r_50_c4",
-    "detectron2_fasterrcnn_r_50_dc5",
-    "detectron2_fasterrcnn_r_50_fpn",
-    "detectron2_maskrcnn_r_101_c4",
-    "detectron2_maskrcnn_r_101_fpn",
-    "detectron2_maskrcnn_r_50_fpn",
-}
-
-# These models support only train mode. So accuracy checking can't be done in
-# eval mode.
-ONLY_TRAINING_MODE = {
-    "tts_angular",
-    "tacotron2",
-    "demucs",
-    "hf_Reformer",
-    "pytorch_struct",
-    "yolov3",
-}
-ONLY_TRAINING_MODE.update(DETECTRON2_MODELS)
-
-# Need lower tolerance on GPU. GPU kernels have non deterministic kernels for these models.
-REQUIRE_HIGHER_TOLERANCE = {
-    "alexnet",
-    "attention_is_all_you_need_pytorch",
-    "densenet121",
-    "hf_Albert",
-    "vgg16",
-    "mobilenet_v3_large",
-    "nvidia_deeprecommender",
-    "timm_efficientdet",
-}
-
-# These models need >1e-3 tolerance
-REQUIRE_EVEN_HIGHER_TOLERANCE = {
-    "soft_actor_critic",
-    "tacotron2",
-}
-
-REQUIRE_HIGHER_FP16_TOLERANCE = {
-    "doctr_reco_predictor",
-    "drq",
-    "hf_Whisper",
-}
-
-
-REQUIRE_HIGHER_BF16_TOLERANCE = {
-    "doctr_reco_predictor",
-    "drq",
-    "hf_Whisper",
-}
-
-REQUIRE_COSINE_TOLERACE = {
-    # Just keeping it here even though its empty, if we need this in future.
-}
-
-# non-deterministic output / cant check correctness
-NONDETERMINISTIC = {
-    # https://github.com/pytorch/pytorch/issues/98355
-    "mobilenet_v3_large",
-}
-
-# These benchmarks took >600s on an i9-11900K CPU
-VERY_SLOW_BENCHMARKS = {
-    "hf_BigBird",  # 3339s
-    "hf_Longformer",  # 3062s
-    "hf_T5",  # 930s
-}
-
-# These benchmarks took >60s on an i9-11900K CPU
-SLOW_BENCHMARKS = {
-    *VERY_SLOW_BENCHMARKS,
-    "BERT_pytorch",  # 137s
-    "demucs",  # 116s
-    "fastNLP_Bert",  # 242s
-    "hf_Albert",  # 221s
-    "hf_Bart",  # 400s
-    "hf_Bert",  # 334s
-    "hf_DistilBert",  # 187s
-    "hf_GPT2",  # 470s
-    "hf_Reformer",  # 141s
-    "speech_transformer",  # 317s
-    "vision_maskrcnn",  # 99s
-}
-
-TRT_NOT_YET_WORKING = {
-    "alexnet",
-    "resnet18",
-    "resnet50",
-    "mobilenet_v2",
-    "mnasnet1_0",
-    "squeezenet1_1",
-    "shufflenetv2_x1_0",
-    "vgg16",
-    "resnext50_32x4d",
-}
-
-DONT_CHANGE_BATCH_SIZE = {
-    "demucs",
-    "pytorch_struct",
-    "pyhpc_turbulent_kinetic_energy",
-    "vision_maskrcnn",  # https://github.com/pytorch/benchmark/pull/1656
-}
-
-MAX_BATCH_SIZE_FOR_ACCURACY_CHECK = {
-    "hf_GPT2": 2,
-    "pytorch_unet": 2,
-}
-
-FORCE_AMP_FOR_FP16_BF16_MODELS = {
-    "DALLE2_pytorch",
-    "doctr_det_predictor",
-    "doctr_reco_predictor",
-    "Super_SloMo",
-    "tts_angular",
-    "pyhpc_turbulent_kinetic_energy",
-    "detectron2_fcos_r_50_fpn",
-}
-
-FORCE_FP16_FOR_BF16_MODELS = {"vision_maskrcnn"}
-
-# models in canary_models that we should run anyway
-CANARY_MODELS = {
-    "torchrec_dlrm",
-    "clip",  # torchbench removed torchtext dependency
-}
-
-
 @functools.lru_cache(maxsize=1)
-def load_skip_file():
-    skip_file_name = "torchbench_skip_models.yaml"
-    skip_file_path = os.path.join(os.path.dirname(__file__), skip_file_name)
+def load_yaml_file():
+    filename = "torchbench.yaml"
+    filepath = os.path.join(os.path.dirname(__file__), filename)

-    with open(skip_file_path) as f:
+    with open(filepath) as f:
         data = yaml.safe_load(f)

+    def flatten(lst):
+        for item in lst:
+            if isinstance(item, list):
+                yield from flatten(item)
+            else:
+                yield item
+
     def maybe_list_to_set(obj):
         if isinstance(obj, dict):
             return {k: maybe_list_to_set(v) for k, v in obj.items()}
         if isinstance(obj, list):
-            return set(obj)
+            return set(flatten(obj))
         return obj

     return maybe_list_to_set(data)
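
Aside: a small, self-contained sketch of what the loader's two helpers do to the parsed YAML. `flatten()` and `maybe_list_to_set()` are copied from the hunk above; the `snippet` string is hypothetical, and whether the actual `torchbench.yaml` splices lists through YAML anchors this way is an assumption (it would mirror how `SLOW_BENCHMARKS` used to include `*VERY_SLOW_BENCHMARKS`). Either way, nested lists are flattened, every list becomes a set, and dicts and scalars pass through untouched.

    import yaml


    def flatten(lst):
        # Yield leaf items from arbitrarily nested lists.
        for item in lst:
            if isinstance(item, list):
                yield from flatten(item)
            else:
                yield item


    def maybe_list_to_set(obj):
        # Recurse into dicts, turn lists into sets of their flattened leaves,
        # and leave scalars (strings, ints) untouched.
        if isinstance(obj, dict):
            return {k: maybe_list_to_set(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return set(flatten(obj))
        return obj


    # Hypothetical config snippet: "slow" splices the "very_slow" list via a
    # YAML alias, which yaml.safe_load parses as a nested list.
    snippet = """
    very_slow: &very_slow
      - hf_BigBird
      - hf_Longformer
    slow:
      - *very_slow
      - BERT_pytorch
    batch_size:
      training:
        demucs: 4
    """

    config = maybe_list_to_set(yaml.safe_load(snippet))
    print(config["slow"])        # {'BERT_pytorch', 'hf_BigBird', 'hf_Longformer'} (a set; order varies)
    print(config["batch_size"])  # {'training': {'demucs': 4}} (dicts and ints unchanged)
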
@@ -221,68 +80,84 @@ def __init__(self):
         self.optimizer = None

     @property
-    def _skip_data(self):
-        return load_skip_file()
+    def _config(self):
+        return load_yaml_file()
+
+    @property
+    def _skip(self):
+        return self._config["skip"]
+
+    @property
+    def _batch_size(self):
+        return self._config["batch_size"]
+
+    @property
+    def _tolerance(self):
+        return self._config["tolerance"]
+
+    @property
+    def _accuracy(self):
+        return self._config["accuracy"]

     @property
     def skip_models(self):
-        return self._skip_data["skip"]
+        return self._skip["all"]

     @property
     def skip_models_for_cpu(self):
-        return self._skip_data["device"]["cpu"]
+        return self._skip["device"]["cpu"]

     @property
     def skip_models_for_cuda(self):
-        return self._skip_data["device"]["cuda"]
+        return self._skip["device"]["cuda"]

     @property
     def slow_models(self):
-        return SLOW_BENCHMARKS
+        return self._config["slow"]

     @property
     def very_slow_models(self):
-        return VERY_SLOW_BENCHMARKS
+        return self._config["very_slow"]

     @property
     def non_deterministic_models(self):
-        return NONDETERMINISTIC
+        return self._config["non_deterministic"]

     @property
     def skip_not_suitable_for_training_models(self):
-        return self._skip_data["test"]["train"]
+        return self._skip["test"]["training"]

     @property
     def failing_fx2trt_models(self):
-        return TRT_NOT_YET_WORKING
+        return self._config["trt_not_yet_working"]

     @property
     def force_amp_for_fp16_bf16_models(self):
-        return FORCE_AMP_FOR_FP16_BF16_MODELS
+        return self._config["dtype"]["force_amp_for_fp16_bf16_models"]

     @property
     def force_fp16_for_bf16_models(self):
-        return FORCE_FP16_FOR_BF16_MODELS
+        return self._config["dtype"]["force_fp16_for_bf16_models"]

     @property
     def skip_accuracy_checks_large_models_dashboard(self):
         if self.args.dashboard or self.args.accuracy:
-            return self._skip_data["accuracy"]["large_models"]
+            return self._accuracy["skip"]["large_models"]
         return set()

     @property
     def skip_accuracy_check_as_eager_non_deterministic(self):
         if self.args.accuracy and self.args.training:
-            return self._skip_data["accuracy"]["eager_not_deterministic"]
+            return self._accuracy["skip"]["eager_not_deterministic"]
         return set()

     @property
     def skip_multiprocess_models(self):
-        return self._skip_data["multiprocess"]
+        return self._skip["multiprocess"]

     @property
     def skip_models_due_to_control_flow(self):
-        return self._skip_data["control_flow"]
+        return self._skip["control_flow"]

     def load_model(
         self,
@@ -322,22 +197,26 @@ def load_model(

         cant_change_batch_size = (
             not getattr(benchmark_cls, "ALLOW_CUSTOMIZE_BSIZE", True)
-            or model_name in DONT_CHANGE_BATCH_SIZE
+            or model_name in self._config["dont_change_batch_size"]
         )
         if cant_change_batch_size:
             batch_size = None
-        if batch_size is None and is_training and model_name in USE_SMALL_BATCH_SIZE:
-            batch_size = USE_SMALL_BATCH_SIZE[model_name]
+        if (
+            batch_size is None
+            and is_training
+            and model_name in self._batch_size["training"]
+        ):
+            batch_size = self._batch_size["training"][model_name]
         elif (
             batch_size is None
             and not is_training
-            and model_name in INFERENCE_SMALL_BATCH_SIZE
+            and model_name in self._batch_size["inference"]
         ):
-            batch_size = INFERENCE_SMALL_BATCH_SIZE[model_name]
+            batch_size = self._batch_size["inference"][model_name]

         # Control the memory footprint for few models
-        if self.args.accuracy and model_name in MAX_BATCH_SIZE_FOR_ACCURACY_CHECK:
-            batch_size = min(batch_size, MAX_BATCH_SIZE_FOR_ACCURACY_CHECK[model_name])
+        if self.args.accuracy and model_name in self._accuracy["max_batch_size"]:
+            batch_size = min(batch_size, self._accuracy["max_batch_size"][model_name])

         # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags"
         torch.backends.__allow_nonbracketed_mutation_flag = True
@@ -378,7 +257,9 @@ def load_model(
         model, example_inputs = benchmark.get_module()

         # Models that must be in train mode while training
-        if is_training and (not use_eval_mode or model_name in ONLY_TRAINING_MODE):
+        if is_training and (
+            not use_eval_mode or model_name in self._config["only_training"]
+        ):
             model.train()
         else:
             model.eval()
@@ -412,7 +293,7 @@ def iter_model_names(self, args):
         models += [
             f
             for f in _list_canary_model_paths()
-            if os.path.basename(f) in CANARY_MODELS
+            if os.path.basename(f) in self._config["canary_models"]
         ]
         models.sort()

@@ -443,21 +324,21 @@ def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
         cosine = self.args.cosine
         # Increase the tolerance for torch allclose
         if self.args.float16 or self.args.amp:
-            if name in REQUIRE_HIGHER_FP16_TOLERANCE:
+            if name in self._tolerance["higher_fp16"]:
                 return 1e-2, cosine
             return 1e-3, cosine

         if self.args.bfloat16:
-            if name in REQUIRE_HIGHER_BF16_TOLERANCE:
+            if name in self._tolerance["higher_bf16"]:
                 return 1e-2, cosine

         if is_training and current_device == "cuda":
             tolerance = 1e-3
-            if name in REQUIRE_COSINE_TOLERACE:
+            if name in self._tolerance["cosine"]:
                 cosine = True
-            elif name in REQUIRE_HIGHER_TOLERANCE:
+            elif name in self._tolerance["higher"]:
                 tolerance = 1e-3
-            elif name in REQUIRE_EVEN_HIGHER_TOLERANCE:
+            elif name in self._tolerance["even_higher"]:
                 tolerance = 8 * 1e-2
         return tolerance, cosine

