Skip to content

Commit

Permalink
Python: Fix setup.py
Browse files Browse the repository at this point in the history
1. Redundant code removed

2. Include the C API headers in source distribution

This completely fixes #87
  • Loading branch information
guysz-nvidia committed Nov 13, 2024
1 parent a850a60 commit 32ded0e
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 54 deletions.
1 change: 0 additions & 1 deletion .github/workflows/cibw.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ jobs:
CIBW_TEST_COMMAND: "pytest {package}/tests"
CIBW_TEST_REQUIRES: "pytest Cython"
CIBW_REPAIR_WHEEL_COMMAND_LINUX: 'auditwheel repair -w {dest_dir} {wheel}'
CIBW_ENVIRONMENT: "C_INCLUDE_PATH=$(pwd)/c/include"
with:
output-dir: dist
package-dir: python
Expand Down
1 change: 1 addition & 0 deletions python/MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
global-include *.pyx
global-exclude *.c
recursive-include include *
68 changes: 18 additions & 50 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,67 +16,35 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See LICENSE.txt for license information.

import os
import glob
import sysconfig
from distutils.sysconfig import get_python_lib
import shutil

from Cython.Build import cythonize
from pathlib import Path
from setuptools import setup
from setuptools.command.sdist import sdist
from setuptools.extension import Extension

cython_files = ["nvtx/**/*.pyx"]

try:
nthreads = int(os.environ.get("PARALLEL_LEVEL", "0") or "0")
except Exception:
nthreads = 0
# ../c/include
c_include_path = Path(__file__).parent.parent / 'c' / 'include'

include_dirs = [os.path.dirname(sysconfig.get_path("include")),]
if os.getenv("CUDA_HOME"):
include_dirs.insert(0, os.path.join(os.environ["CUDA_HOME"], "include"))
library_dirs = [get_python_lib()]
# When building from source distribution (.tar.gz), ./include dir exists (added by sdist command)
# Otherwise, we are building from sources, so we need to use `c_include_path`
include_dirs = ['include' if Path('include').exists() else str(c_include_path)]

if nvtx_include_dir := os.getenv("NVTX_PREFIX"):
include_dirs.insert(0, nvtx_include_dir)

extensions = [
Extension(
"*",
sources=cython_files,
include_dirs=include_dirs,
library_dirs=library_dirs,
language="c",
)
]

cython_tests = glob.glob("nvtx/_lib/tests/*.pyx")

# tests:
extensions += cythonize(
[
Extension(
"*",
sources=cython_tests,
include_dirs=include_dirs,
library_dirs=library_dirs,
language="c"
)
],
nthreads=nthreads,
compiler_directives=dict(
profile=True, language_level=3, embedsignature=True, binding=True
),
)
class NvtxSdist(sdist):
def run(self):
try:
shutil.copytree(c_include_path, 'include')
super().run()
finally:
shutil.rmtree('include', ignore_errors=True)


setup(
# Include the separately-compiled shared library
cmdclass=dict(sdist=NvtxSdist),
ext_modules=cythonize(
extensions,
nthreads=nthreads,
compiler_directives=dict(
profile=False, language_level=3, embedsignature=True
),
),
Extension('*', sources=['src/nvtx/_lib/*.pyx'], include_dirs=include_dirs),
compiler_directives=dict(language_level=3, embedsignature=True))
)
3 changes: 0 additions & 3 deletions python/tools/build-wheels.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ function repair_wheel {
fi
}


export C_INCLUDE_PATH=/io/c/include

# Compile wheels
for PY_VERSION in 38 39 310 311; do
PYBIN=/opt/python/cp${PY_VERSION}*/bin/
Expand Down

0 comments on commit 32ded0e

Please sign in to comment.