-
Notifications
You must be signed in to change notification settings - Fork 20
/
setup.py
41 lines (37 loc) · 1.09 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDA_HOME
from torch.utils.cpp_extension import CppExtension, CUDAExtension
# The CPU extension is always part of the build, whatever the machine.
modules = [
    CppExtension(
        'torchsearchsorted.cpu',
        ['src/cpu/searchsorted_cpu_wrapper.cpp'],
    ),
]

# Add the GPU extension only when torch reports a CUDA toolkit location
# (CUDA_HOME is truthy); otherwise the package is CPU-only.
if CUDA_HOME:
    modules += [
        CUDAExtension(
            'torchsearchsorted.cuda',
            [
                'src/cuda/searchsorted_cuda_wrapper.cpp',
                'src/cuda/searchsorted_cuda_kernel.cu',
            ],
        ),
    ]

# Requirements for running the test suite; also exposed as the "test" extra.
tests_require = ['pytest']

setup(
    # --- distribution metadata ---
    name='torchsearchsorted',
    version='1.1',
    description='A searchsorted implementation for pytorch',
    keywords='searchsorted',
    author='Antoine Liutkus',
    # NOTE(review): this value looks like a scraped email-obfuscation
    # placeholder rather than a real address -- confirm the intended email.
    author_email='[email protected]',
    # --- python sources: src/ layout ---
    packages=find_packages(where='src'),
    package_dir={"": "src"},
    # --- compiled extensions, built through torch's build_ext replacement ---
    ext_modules=modules,
    cmdclass={'build_ext': BuildExtension},
    # --- test dependencies ---
    tests_require=tests_require,
    extras_require={'test': tests_require},
)