Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool committed Aug 24, 2023
1 parent bebaa2c commit bf07ea0
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 8 deletions.
10 changes: 10 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[flake8]
show-source=true
statistics=true
max-line-length=80

exclude =
.git,
.github,
setup.py,
build,
4 changes: 2 additions & 2 deletions .github/workflows/run_tests_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
cuda: ["11.7"]
torch: ["1.13.1"]
cuda: ["11.6"]
torch: ["1.12.1"]
python-version: ["3.9"]
build_type: ["Release", "Debug"]

Expand Down
8 changes: 6 additions & 2 deletions fast_rnnt/python/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@ endif()
pybind11_add_module(_fast_rnnt ${fast_rnnt_srcs})
target_link_libraries(_fast_rnnt PRIVATE mutual_information_core)

if(UNIX AND NOT APPLE)
if(APPLE)
target_link_libraries(__fast_rnnt
PRIVATE
${TORCH_DIR}/lib/libtorch_python.dylib
)
elseif(UNIX)
target_link_libraries(_fast_rnnt
PRIVATE
${PYTHON_LIBRARY}
${TORCH_DIR}/lib/libtorch_python.so
)
endif()

8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def build_extension(self, ext: setuptools.extension.Extension):
cmake_args = "-DCMAKE_BUILD_TYPE=Release -DFT_BUILD_TESTS=OFF"

if make_args == "" and system_make_args == "":
make_args = ' -j '
make_args = " -j "

if "PYTHON_EXECUTABLE" not in cmake_args:
print(f"Setting PYTHON_EXECUTABLE to {sys.executable}")
Expand Down Expand Up @@ -89,17 +89,17 @@ def get_package_version():
latest_version = latest_version.strip('"')
return latest_version


def get_requirements():
with open("requirements.txt", encoding="utf8") as f:
requirements = f.read().splitlines()

return requirements


package_name = "fast_rnnt"

with open(
"fast_rnnt/python/fast_rnnt/__init__.py", "a"
) as f:
with open("fast_rnnt/python/fast_rnnt/__init__.py", "a") as f:
f.write(f"__version__ = '{get_package_version()}'\n")

setuptools.setup(
Expand Down

0 comments on commit bf07ea0

Please sign in to comment.