From 0a7566d1fda872c5e5b879c2db530adcbd0dfe87 Mon Sep 17 00:00:00 2001 From: Dave Bort Date: Thu, 18 Apr 2024 14:41:14 -0700 Subject: [PATCH] Add a pure python wrapper fo pybindings.{aten,portable}_lib (#3137) Summary: When installed as a pip wheel, we must import `torch` before trying to import the pybindings shared library extension. This will load libtorch.so and related libs, ensuring that the pybindings lib can resolve those runtime dependencies. So, add a pure python wrapper that lets us do this when users say `import executorch.extension.pybindings.portable_lib` We only need this for OSS, so don't bother doing this for other pybindings targets. Differential Revision: D56317150 --- CMakeLists.txt | 5 ++++- extension/pybindings/TARGETS | 16 ++++++++++---- extension/pybindings/portable_lib.py | 32 ++++++++++++++++++++++++++++ setup.py | 2 +- 4 files changed, 49 insertions(+), 6 deletions(-) create mode 100644 extension/pybindings/portable_lib.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 7dd6239b81..5c1747953f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -548,8 +548,11 @@ if(EXECUTORCH_BUILD_PYBIND) # pybind portable_lib pybind11_add_module(portable_lib extension/pybindings/pybindings.cpp) + # The actual output file needs a leading underscore so it can coexist with + # portable_lib.py in the same python package. + set_target_properties(portable_lib PROPERTIES OUTPUT_NAME "_portable_lib") target_compile_definitions(portable_lib - PUBLIC EXECUTORCH_PYTHON_MODULE_NAME=portable_lib) + PUBLIC EXECUTORCH_PYTHON_MODULE_NAME=_portable_lib) target_include_directories(portable_lib PRIVATE ${TORCH_INCLUDE_DIRS}) target_compile_options(portable_lib PUBLIC ${_pybind_compile_options}) target_link_libraries( diff --git a/extension/pybindings/TARGETS b/extension/pybindings/TARGETS index 0b4e9ef304..9dee0e208b 100644 --- a/extension/pybindings/TARGETS +++ b/extension/pybindings/TARGETS @@ -30,9 +30,9 @@ runtime.genrule( srcs = [":pybinding_types"], outs = { "aten_lib.pyi": ["aten_lib.pyi"], - "portable_lib.pyi": ["portable_lib.pyi"], + "_portable_lib.pyi": ["_portable_lib.pyi"], }, - cmd = "cp $(location :pybinding_types)/* $OUT/portable_lib.pyi && cp $(location :pybinding_types)/* $OUT/aten_lib.pyi", + cmd = "cp $(location :pybinding_types)/* $OUT/_portable_lib.pyi && cp $(location :pybinding_types)/* $OUT/aten_lib.pyi", visibility = ["//executorch/extension/pybindings/..."], ) @@ -46,8 +46,9 @@ executorch_pybindings( executorch_pybindings( compiler_flags = ["-std=c++17"], cppdeps = PORTABLE_MODULE_DEPS + MODELS_ATEN_OPS_LEAN_MODE_GENERATED_LIB, - python_module_name = "portable_lib", - types = ["//executorch/extension/pybindings:pybindings_types_gen[portable_lib.pyi]"], + # Give this an underscore prefix because it has a pure python wrapper. + python_module_name = "_portable_lib", + types = ["//executorch/extension/pybindings:pybindings_types_gen[_portable_lib.pyi]"], visibility = ["PUBLIC"], ) @@ -58,3 +59,10 @@ executorch_pybindings( types = ["//executorch/extension/pybindings:pybindings_types_gen[aten_lib.pyi]"], visibility = ["PUBLIC"], ) + +runtime.python_library( + name = "portable_lib", + srcs = ["portable_lib.py"], + visibility = ["@EXECUTORCH_CLIENTS"], + deps = [":_portable_lib"], +) diff --git a/extension/pybindings/portable_lib.py b/extension/pybindings/portable_lib.py new file mode 100644 index 0000000000..6abc1db47e --- /dev/null +++ b/extension/pybindings/portable_lib.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +# When installed as a pip wheel, we must import `torch` before trying to import +# the pybindings shared library extension. This will load libtorch.so and +# related libs, ensuring that the pybindings lib can resolve those runtime +# dependencies. +import torch as _torch + +# Import the actual C++ extension that this file wraps. +from executorch.extension.pybindings import _portable_lib + +# Let users import everything from _portable_lib as if this python file defined +# them. Normally we'd exclude names starting with `_`, but _portable_lib +# contains names like `_load_for_executorch` that we need to expose. +__all__ = [name for name in dir(_portable_lib) if not name.startswith("__")] + +# The underscores also complicate things because it means we can't use `import +# *` to bring them into our namespace. +for _name in __all__: + exec(f"from executorch.extension.pybindings._portable_lib import {_name}") + +# Clean up so that `dir(portable_lib)` is the same as `dir(_portable_lib)` +# (modulo some __dunder__ names). +del _name +del _portable_lib +del _torch diff --git a/setup.py b/setup.py index 6458fdc10c..fa67bebf0b 100644 --- a/setup.py +++ b/setup.py @@ -435,7 +435,7 @@ def get_ext_modules() -> list[Extension]: # portable kernels, and a selection of backends. This lets users # load and execute .pte files from python. BuiltExtension( - "portable_lib.*", "executorch.extension.pybindings.portable_lib" + "_portable_lib.*", "executorch.extension.pybindings._portable_lib" ) )