Add torch._export.aot_load (#2119)
Summary:
Pull Request resolved: #2119

X-link: pytorch/pytorch#117610

Add a torch._export.aot_load API that can load an AOTInductor-compiled model.so into a Python executable.

Reviewed By: khabinov, angelayi

Differential Revision: D52825456

fbshipit-source-id: 1cc2e93f4621863d2535360edca2d72a305bafdf
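
For context, a minimal sketch of the round trip this enables, pairing the new API with the existing torch._export.aot_compile. The module, shapes, and device below are illustrative, and the loaded module is assumed to take its inputs positionally:

import torch
import torch._export

class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Add().to(device)
example_args = (
    torch.randn(8, 8, device=device),
    torch.randn(8, 8, device=device),
)

# Ahead-of-time compile the model into a shared library on disk.
so_path = torch._export.aot_compile(model, example_args)

# Load the compiled model.so back into this Python process and run it
# (assumed calling convention: inputs passed positionally).
optimized = torch._export.aot_load(so_path, device)
out = optimized(*example_args)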
desertfire authored and facebook-github-bot committed Jan 18, 2024
1 parent 6a8b941 commit e85d944
1 changed file: userbenchmark/dynamo/dynamobench/common.py (3 additions, 18 deletions)
@@ -51,7 +51,6 @@
 import torch._dynamo.utils
 import torch._export
 import torch.distributed
-import torch.fx._pytree as fx_pytree
 import torch.multiprocessing as mp
 from scipy.stats import gmean, ttest_ind
 from torch._dynamo.profiler import fx_insert_profiling, Profiler
@@ -1126,13 +1125,7 @@ def load(cls, model, example_inputs, device):
         _register_dataclass_output_as_pytree(example_outputs)

         so_path = torch._export.aot_compile(model, example_args, example_kwargs)
-
-        runner = (
-            torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)
-            if device == "cpu"
-            else torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1)
-        )
-        cls.cache[key] = runner
+        cls.cache[key] = torch._export.aot_load(so_path, device)

         return cls.cache[key]

@@ -1152,19 +1145,11 @@ def opt_export(_, example_inputs):


 def export_aot_inductor(model, example_inputs, device):
-    runner = AOTInductorModelCache.load(model, example_inputs, device)
-    call_spec = runner.get_call_spec()
-    in_spec = pytree.treespec_loads(call_spec[0])
-    out_spec = pytree.treespec_loads(call_spec[1])
+    optimized = AOTInductorModelCache.load(model, example_inputs, device)

     def opt_aot_inductor(_, example_inputs, collect_outputs=False):
         example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
-
-        flat_inputs = fx_pytree.tree_flatten_spec(
-            (example_args, example_kwargs), in_spec
-        )
-        flat_outputs = runner.run(flat_inputs)
-        return pytree.tree_unflatten(flat_outputs, out_spec)
+        return optimized(example_args, example_kwargs)

     return opt_aot_inductor

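The deleted lines above spell out the boilerplate that torch._export.aot_load now encapsulates: constructing a device-specific AOTIModelContainerRunner, flattening structured inputs with the serialized pytree specs, and unflattening the outputs. A hedged reconstruction assembled from the removed code (the helper name load_manually is hypothetical):

import torch
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree

def load_manually(so_path, device):
    # Device-specific runner around the compiled model container (old code path).
    runner = (
        torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)
        if device == "cpu"
        else torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1)
    )

    # The input/output pytree specs are stored alongside the compiled model.
    call_spec = runner.get_call_spec()
    in_spec = pytree.treespec_loads(call_spec[0])
    out_spec = pytree.treespec_loads(call_spec[1])

    def run(example_args, example_kwargs):
        # Flatten structured inputs, run the compiled model, restore the output structure.
        flat_inputs = fx_pytree.tree_flatten_spec(
            (example_args, example_kwargs), in_spec
        )
        flat_outputs = runner.run(flat_inputs)
        return pytree.tree_unflatten(flat_outputs, out_spec)

    return run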
