Skip to content

Commit

Permalink
Enable bf16 on all models
Browse files Browse the repository at this point in the history
Summary: Enable bf16 precision on all models: `check_precision` now returns True for the "bf16" precision on any device, instead of requiring a CPU device with an `enable_bf16` attribute. Also corrects the staged-train error message to reference `optimizer_step` rather than `optimizer`.

Reviewed By: davidberard98

Differential Revision: D50752308

fbshipit-source-id: 925af557310bd42772f71e9d51190b45c0e2447b
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Oct 27, 2023
1 parent f235144 commit adc4b0c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchbenchmark/util/extra_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def check_precision(model: 'torchbenchmark.util.model.BenchmarkModel', precision
if precision == "fx_int8":
return model.device == 'cpu' and hasattr(model, "enable_fx_int8")
if precision == "bf16":
return model.device == 'cpu' and hasattr(model, "enable_bf16")
return True
if precision == "amp_fp16":
if model.test == 'eval' and model.device == 'cuda':
return True
Expand Down Expand Up @@ -77,7 +77,7 @@ def parse_decoration_args(model: 'torchbenchmark.util.model.BenchmarkModel', ext
if not check_precision(model, dargs.precision):
raise NotImplementedError(f"precision value: {dargs.precision}, "
"amp is only supported if cuda+eval, or if `enable_amp` implemented,"
"or if model uses staged train interfaces (forward, backward, optimizer).")
"or if model uses staged train interfaces (forward, backward, optimizer_step).")
if not check_memory_layout(model, dargs.channels_last):
raise NotImplementedError(f"Specified channels_last: {dargs.channels_last} ,"
f" but the model doesn't implement the enable_channels_last() interface.")
Expand Down

0 comments on commit adc4b0c

Please sign in to comment.