Skip to content

Commit

Permalink
Add a pure python wrapper fo pybindings.{aten,portable}_lib (#3137)
Browse files Browse the repository at this point in the history
Summary:

When installed as a pip wheel, we must import `torch` before trying to import the pybindings shared library extension. This will load libtorch.so and related libs, ensuring that the pybindings lib can resolve those runtime dependencies.

So, add a pure python wrapper that lets us do this when users say `import executorch.extension.pybindings.portable_lib`

We only need this for OSS, so don't bother doing this for other pybindings targets.

Differential Revision: D56317150
  • Loading branch information
dbort committed Apr 18, 2024
1 parent 014aa37 commit 0a7566d
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 6 deletions.
5 changes: 4 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,11 @@ if(EXECUTORCH_BUILD_PYBIND)

# pybind portable_lib
pybind11_add_module(portable_lib extension/pybindings/pybindings.cpp)
# The actual output file needs a leading underscore so it can coexist with
# portable_lib.py in the same python package.
set_target_properties(portable_lib PROPERTIES OUTPUT_NAME "_portable_lib")
target_compile_definitions(portable_lib
PUBLIC EXECUTORCH_PYTHON_MODULE_NAME=portable_lib)
PUBLIC EXECUTORCH_PYTHON_MODULE_NAME=_portable_lib)
target_include_directories(portable_lib PRIVATE ${TORCH_INCLUDE_DIRS})
target_compile_options(portable_lib PUBLIC ${_pybind_compile_options})
target_link_libraries(
Expand Down
16 changes: 12 additions & 4 deletions extension/pybindings/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ runtime.genrule(
srcs = [":pybinding_types"],
outs = {
"aten_lib.pyi": ["aten_lib.pyi"],
"portable_lib.pyi": ["portable_lib.pyi"],
"_portable_lib.pyi": ["_portable_lib.pyi"],
},
cmd = "cp $(location :pybinding_types)/* $OUT/portable_lib.pyi && cp $(location :pybinding_types)/* $OUT/aten_lib.pyi",
cmd = "cp $(location :pybinding_types)/* $OUT/_portable_lib.pyi && cp $(location :pybinding_types)/* $OUT/aten_lib.pyi",
visibility = ["//executorch/extension/pybindings/..."],
)

Expand All @@ -46,8 +46,9 @@ executorch_pybindings(
executorch_pybindings(
compiler_flags = ["-std=c++17"],
cppdeps = PORTABLE_MODULE_DEPS + MODELS_ATEN_OPS_LEAN_MODE_GENERATED_LIB,
python_module_name = "portable_lib",
types = ["//executorch/extension/pybindings:pybindings_types_gen[portable_lib.pyi]"],
# Give this an underscore prefix because it has a pure python wrapper.
python_module_name = "_portable_lib",
types = ["//executorch/extension/pybindings:pybindings_types_gen[_portable_lib.pyi]"],
visibility = ["PUBLIC"],
)

Expand All @@ -58,3 +59,10 @@ executorch_pybindings(
types = ["//executorch/extension/pybindings:pybindings_types_gen[aten_lib.pyi]"],
visibility = ["PUBLIC"],
)

runtime.python_library(
name = "portable_lib",
srcs = ["portable_lib.py"],
visibility = ["@EXECUTORCH_CLIENTS"],
deps = [":_portable_lib"],
)
32 changes: 32 additions & 0 deletions extension/pybindings/portable_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

# When installed as a pip wheel, we must import `torch` before trying to import
# the pybindings shared library extension. This will load libtorch.so and
# related libs, ensuring that the pybindings lib can resolve those runtime
# dependencies.
import torch as _torch

# Import the actual C++ extension that this file wraps.
from executorch.extension.pybindings import _portable_lib

# Let users import everything from _portable_lib as if this python file defined
# them. Normally we'd exclude names starting with `_`, but _portable_lib
# contains names like `_load_for_executorch` that we need to expose.
__all__ = [name for name in dir(_portable_lib) if not name.startswith("__")]

# The underscores also complicate things because it means we can't use `import
# *` to bring them into our namespace.
for _name in __all__:
exec(f"from executorch.extension.pybindings._portable_lib import {_name}")

# Clean up so that `dir(portable_lib)` is the same as `dir(_portable_lib)`
# (modulo some __dunder__ names).
del _name
del _portable_lib
del _torch
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def get_ext_modules() -> list[Extension]:
# portable kernels, and a selection of backends. This lets users
# load and execute .pte files from python.
BuiltExtension(
"portable_lib.*", "executorch.extension.pybindings.portable_lib"
"_portable_lib.*", "executorch.extension.pybindings._portable_lib"
)
)

Expand Down

0 comments on commit 0a7566d

Please sign in to comment.