diff --git a/userbenchmark/dynamo/dynamobench/common.py b/userbenchmark/dynamo/dynamobench/common.py index be241a4957..8796da69a6 100644 --- a/userbenchmark/dynamo/dynamobench/common.py +++ b/userbenchmark/dynamo/dynamobench/common.py @@ -716,7 +716,9 @@ def maybe_mark_profile(*args, **kwargs): with maybe_profile(args.export_profiler_trace) as p: if args.export_aot_inductor: - frozen_model_iter_fn = export_aot_inductor(model, example_inputs) + frozen_model_iter_fn = export_aot_inductor( + model, example_inputs, args.devices[0] + ) else: frozen_model_iter_fn = torch._dynamo.run(model_iter_fn) @@ -1165,7 +1167,7 @@ class AOTInductorModelCache: cache = dict() @classmethod - def load(cls, model, example_inputs): + def load(cls, model, example_inputs, device): key = weakref.ref(model) if key not in cls.cache: # Register the output dataclass to pytree @@ -1179,10 +1181,9 @@ def load(cls, model, example_inputs): module = torch.utils.cpp_extension.load_inline( name="aot_inductor", - cpp_sources=[aot_inductor_launcher], + cpp_sources=[aot_inductor_launcher(so_path, device)], functions=["run"], - extra_ldflags=[so_path], - with_cuda=True, + with_cuda=(device == "cuda"), ) value = { @@ -1211,8 +1212,8 @@ def opt_export(_, example_inputs): return opt_export -def export_aot_inductor(model, example_inputs): - module, exported = AOTInductorModelCache.load(model, example_inputs) +def export_aot_inductor(model, example_inputs, device): + module, exported = AOTInductorModelCache.load(model, example_inputs, device) def opt_aot_inductor(_, example_inputs, collect_outputs=False): example_args, example_kwargs = _normalize_bench_inputs(example_inputs) @@ -3596,8 +3597,9 @@ def run(runner, args, original_dir=None): elif args.backend or args.export_aot_inductor: if args.export_aot_inductor: assert not args.training, "AOTInductor only supports inference" - assert args.devices == ["cuda"], "AOTInductor only tested for CUDA" - optimize_ctx = export_aot_inductor + optimize_ctx = functools.partial( + export_aot_inductor, device=args.devices[0] + ) # AOTInductor doesn't support control flow yet runner.skip_models.update(runner.skip_models_due_to_control_flow)