From 285ad4dd43c426dc3240310643ea8885adc3f4a2 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Thu, 24 Aug 2023 15:08:05 +0800 Subject: [PATCH 1/2] Add github actions (#30) --- .github/scripts/install_cuda.sh | 72 ++++++++++ .github/scripts/install_cudnn.sh | 62 +++++++++ .github/scripts/install_torch.sh | 188 +++++++++++++++++++++++++++ .github/workflows/run_tests_cpu.yml | 124 ++++++++++++++++++ .github/workflows/run_tests_cuda.yml | 135 +++++++++++++++++++ .github/workflows/style_check.yml | 69 ++++++++++ 6 files changed, 650 insertions(+) create mode 100644 .github/scripts/install_cuda.sh create mode 100644 .github/scripts/install_cudnn.sh create mode 100644 .github/scripts/install_torch.sh create mode 100644 .github/workflows/run_tests_cpu.yml create mode 100644 .github/workflows/run_tests_cuda.yml create mode 100644 .github/workflows/style_check.yml diff --git a/.github/scripts/install_cuda.sh b/.github/scripts/install_cuda.sh new file mode 100644 index 0000000..f7a669a --- /dev/null +++ b/.github/scripts/install_cuda.sh @@ -0,0 +1,72 @@ +#!/bin/bash +# +# Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +echo "cuda version: $cuda" + +case "$cuda" in + 10.0) + url=https://developer.nvidia.com/compute/cuda/10.0/Prod/local_installers/cuda_10.0.130_410.48_linux + ;; + 10.1) + # WARNING: there are bugs in + # https://developer.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.105_418.39_linux.run + # with GCC 7. Please use the following version + url=http://developer.download.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.243_418.87.00_linux.run + ;; + 10.2) + url=http://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda_10.2.89_440.33.01_linux.run + ;; + 11.0) + url=http://developer.download.nvidia.com/compute/cuda/11.0.2/local_installers/cuda_11.0.2_450.51.05_linux.run + ;; + 11.1) + # url=https://developer.download.nvidia.com/compute/cuda/11.1.0/local_installers/cuda_11.1.0_455.23.05_linux.run + url=https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run + ;; + 11.3) + # url=https://developer.download.nvidia.com/compute/cuda/11.3.0/local_installers/cuda_11.3.0_465.19.01_linux.run + url=https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.19.01_linux.run + ;; + 11.5) + url=https://developer.download.nvidia.com/compute/cuda/11.5.2/local_installers/cuda_11.5.2_495.29.05_linux.run + ;; + 11.6) + url=https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda_11.6.2_510.47.03_linux.run + ;; + 11.7) + url=https://developer.download.nvidia.com/compute/cuda/11.7.1/local_installers/cuda_11.7.1_515.65.01_linux.run + ;; + *) + echo "Unknown cuda version: $cuda" + exit 1 + ;; +esac + +function retry() { + $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) +} + +retry curl -LSs -O $url +filename=$(basename $url) +echo "filename: $filename" +chmod +x ./$filename +sudo ./$filename --toolkit --silent +rm -fv ./$filename + +export CUDA_HOME=/usr/local/cuda +export PATH=$CUDA_HOME/bin:$PATH +export LD_LIBRARY_PATH=$CUDA_HOME/lib:$LD_LIBRARY_PATH +export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH diff --git a/.github/scripts/install_cudnn.sh b/.github/scripts/install_cudnn.sh new file mode 100644 index 0000000..d57018c --- /dev/null +++ b/.github/scripts/install_cudnn.sh @@ -0,0 +1,62 @@ +#!/bin/bash +# +# Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +case $cuda in + 10.0) + filename=cudnn-10.0-linux-x64-v7.6.5.32.tgz + ;; + 10.1) + filename=cudnn-10.1-linux-x64-v8.0.2.39.tgz + ;; + 10.2) + filename=cudnn-10.2-linux-x64-v8.0.2.39.tgz + ;; + 11.0) + filename=cudnn-11.0-linux-x64-v8.0.5.39.tgz + ;; + 11.1) + filename=cudnn-11.1-linux-x64-v8.0.4.30.tgz + ;; + 11.3) + filename=cudnn-11.3-linux-x64-v8.2.0.53.tgz + ;; + 11.5) + filename=cudnn-11.3-linux-x64-v8.2.0.53.tgz + ;; + 11.6) + filename=cudnn-11.3-linux-x64-v8.2.0.53.tgz + ;; + 11.7) + filename=cudnn-11.3-linux-x64-v8.2.0.53.tgz + ;; + *) + echo "Unsupported cuda version: $cuda" + exit 1 + ;; +esac + +command -v git-lfs >/dev/null 2>&1 || { echo >&2 "\nPlease install 'git-lfs' first."; exit 2; } + +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/cudnn +cd cudnn +git lfs pull --include="$filename" + +sudo tar xf ./$filename --strip-components=1 -C /usr/local/cuda + +# save disk space +git lfs prune && cd .. && rm -rf cudnn + +sudo sed -i '59i#define CUDNN_MAJOR 8' /usr/local/cuda/include/cudnn.h diff --git a/.github/scripts/install_torch.sh b/.github/scripts/install_torch.sh new file mode 100644 index 0000000..7ba7485 --- /dev/null +++ b/.github/scripts/install_torch.sh @@ -0,0 +1,188 @@ +#!/bin/bash +# +# Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +torch=$TORCH_VERSION +cuda=$CUDA_VERSION +case ${torch} in + 1.5.*) + case ${cuda} in + 10.1) + package="torch==${torch}+cu101" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + 10.2) + package="torch==${torch}" + # Leave url empty to use PyPI. + # torch_stable provides cu92 but we want cu102 + url= + ;; + esac + ;; + 1.6.0) + case ${cuda} in + 10.1) + package="torch==1.6.0+cu101" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + 10.2) + package="torch==1.6.0" + # Leave it empty to use PyPI. + # torch_stable provides cu92 but we want cu102 + url= + ;; + esac + ;; + 1.7.*) + case ${cuda} in + 10.1) + package="torch==${torch}+cu101" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + 10.2) + package="torch==${torch}" + # Leave it empty to use PyPI. + # torch_stable provides cu92 but we want cu102 + url= + ;; + 11.0) + package="torch==${torch}+cu110" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + esac + ;; + 1.8.*) + case ${cuda} in + 10.1) + package="torch==${torch}+cu101" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + 10.2) + package="torch==${torch}" + # Leave it empty to use PyPI. + url= + ;; + 11.1) + package="torch==${torch}+cu111" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + esac + ;; + 1.9.*) + case ${cuda} in + 10.2) + package="torch==${torch}" + # Leave it empty to use PyPI. + url= + ;; + 11.1) + package="torch==${torch}+cu111" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + esac + ;; + 1.10.*) + case ${cuda} in + 10.2) + package="torch==${torch}" + # Leave it empty to use PyPI. + url= + ;; + 11.1) + package="torch==${torch}+cu111" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + 11.3) + package="torch==${torch}+cu113" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + esac + ;; + 1.11.*) + case ${cuda} in + 10.2) + package="torch==${torch}" + # Leave it empty to use PyPI. + url= + ;; + 11.3) + package="torch==${torch}+cu113" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + 11.5) + package="torch==${torch}+cu115" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + esac + ;; + 1.12.*) + case ${cuda} in + 10.2) + package="torch==${torch}" + # Leave it empty to use PyPI. + url= + ;; + 11.3) + package="torch==${torch}+cu113" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + 11.6) + package="torch==${torch}+cu116" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + esac + ;; + 1.13.*) + case ${cuda} in + 11.6) + package="torch==${torch}+cu116" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + 11.7) + package="torch==${torch}" + # Leave it empty to use PyPI. + url= + ;; + esac + ;; + 2.0.*) + case ${cuda} in + 11.7) + package="torch==${torch}+cu117" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + 11.8) + package="torch==${torch}+cu118" + url=https://download.pytorch.org/whl/torch_stable.html + ;; + esac + ;; + *) + echo "Unsupported PyTorch version: ${torch}" + exit 1 + ;; +esac + +function retry() { + $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) +} + +if [ x"${url}" == "x" ]; then + retry python3 -m pip install -q $package +else + retry python3 -m pip install -q $package -f $url +fi + +rm -rfv ~/.cache/pip diff --git a/.github/workflows/run_tests_cpu.yml b/.github/workflows/run_tests_cpu.yml new file mode 100644 index 0000000..51baada --- /dev/null +++ b/.github/workflows/run_tests_cpu.yml @@ -0,0 +1,124 @@ +# Copyright 2023 Xiaomi Corp. (Wei Kang) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# refer to https://github.com/actions/starter-workflows/pull/47/files + +name: run-tests-cpu + +on: + push: + branches: + - master + paths: + - '.github/workflows/run_tests_cpu.yml' + - 'CMakeLists.txt' + - 'cmake/**' + - 'fast_rnnt/csrc/**' + - 'fast_rnnt/python/**' + pull_request: + branches: + - master + paths: + - '.github/workflows/run_tests_cpu.yml' + - 'CMakeLists.txt' + - 'cmake/**' + - 'fast_rnnt/csrc/**' + - 'fast_rnnt/python/**' + +concurrency: + group: run-tests-cpu-${{ github.ref }} + cancel-in-progress: true + +jobs: + run-tests-cpu: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + torch: ["1.13.1"] + torchaudio: ["0.13.1"] + python-version: ["3.11"] + build_type: ["Release", "Debug"] + + steps: + # refer to https://github.com/actions/checkout + - uses: actions/checkout@v2 + + - name: Display GCC version + run: | + gcc --version + + - name: Display clang version + if: startsWith(matrix.os, 'macos') + run: | + clang --version + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Display Python version + run: python -c "import sys; print(sys.version)" + + - name: Install PyTorch ${{ matrix.torch }} + if: startsWith(matrix.os, 'ubuntu') + shell: bash + run: | + python3 -m pip install -qq --upgrade pip + python3 -m pip install -qq torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html + python3 -m pip install -qq torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html + python3 -c "import torch; print('torch version:', torch.__version__)" + + python3 -m torch.utils.collect_env + + - name: Install PyTorch ${{ matrix.torch }} + if: startsWith(matrix.os, 'macos') + shell: bash + run: | + python3 -m pip install -qq --upgrade pip + python3 -m pip install -qq torch==${{ matrix.torch }} + python3 -m pip install -qq torch==${{ matrix.torchaudio }} + python3 -c "import torch; print('torch version:', torch.__version__)" + + python3 -m torch.utils.collect_env + + - name: Configure CMake + shell: bash + env: + torch: ${{ matrix.torch }} + run: | + mkdir build + cd build + cmake -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} -DFT_WITH_CUDA=OFF .. + + - name: ${{ matrix.build_type }} Build + shell: bash + run: | + cd build + make -j2 VERBOSE=1 + + - name: Display Build Information + shell: bash + run: | + export PYTHONPATH=$PWD/fast_rnnt/python:$PWD/build/lib:$PYTHONPATH + + - name: Run Tests + shell: bash + run: | + cd build + ctest --output-on-failure diff --git a/.github/workflows/run_tests_cuda.yml b/.github/workflows/run_tests_cuda.yml new file mode 100644 index 0000000..1e85115 --- /dev/null +++ b/.github/workflows/run_tests_cuda.yml @@ -0,0 +1,135 @@ +# Copyright 2023 Xiaomi Corp. (Wei Kang) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-tests-cuda + +on: + push: + branches: + - master + paths: + - '.github/workflows/run_tests_cuda.yml' + - 'CMakeLists.txt' + - 'cmake/**' + - 'fast_rnnt/csrc/**' + - 'fast_rnnt/python/**' + pull_request: + branches: + - master + paths: + - '.github/workflows/run_tests_cuda.yml' + - 'CMakeLists.txt' + - 'cmake/**' + - 'fast_rnnt/csrc/**' + - 'fast_rnnt/python/**' + +concurrency: + group: run-tests-${{ github.ref }} + cancel-in-progress: true + +jobs: + run-tests: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + cuda: ["11.7"] + torch: ["1.13.1"] + torchaudio: ["1.13.0"] + python-version: ["3.11"] + build_type: ["Release", "Debug"] + + steps: + # refer to https://github.com/actions/checkout + - uses: actions/checkout@v2 + + - name: Install CUDA Toolkit ${{ matrix.cuda }} + env: + cuda: ${{ matrix.cuda }} + run: | + source ./.github/scripts/install_cuda.sh + echo "CUDA_HOME=${CUDA_HOME}" >> $GITHUB_ENV + echo "${CUDA_HOME}/bin" >> $GITHUB_PATH + echo "LD_LIBRARY_PATH=${CUDA_HOME}/lib:${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}" >> $GITHUB_ENV + shell: bash + + - name: Display NVCC version + run: | + which nvcc + nvcc --version + + - name: Display GCC version + run: | + gcc --version + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Display Python version + run: python -c "import sys; print(sys.version)" + + - name: Install PyTorch ${{ matrix.torch }} + env: + cuda: ${{ matrix.cuda }} + torch: ${{ matrix.torch }} + shell: bash + run: | + python3 -m pip install -qq --upgrade pip + + ./.github/scripts/install_torch.sh + python3 -c "import torch; print('torch version:', torch.__version__)" + + - name: Install git lfs + run: | + sudo apt-get install -y git-lfs + + - name: Download cudnn 8.0 + env: + cuda: ${{ matrix.cuda }} + run: | + ./scripts/github_actions/install_cudnn.sh + + - name: Configure CMake + shell: bash + env: + torch: ${{ matrix.torch }} + run: | + mkdir build + cd build + cmake -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} .. + + - name: ${{ matrix.build_type }} Build + shell: bash + run: | + echo "number of cores: $(nproc)" + cd build + # we cannot use -j here because of limited RAM + # of the VM provided by GitHub actions + make VERBOSE=1 -j2 + + - name: Display Build Information + shell: bash + run: | + export PYTHONPATH=$PWD/fast_rnnt/python:$PWD/build/lib:$PYTHONPATH + + - name: Run Tests + shell: bash + run: | + cd build + ctest --output-on-failure diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml new file mode 100644 index 0000000..270ec85 --- /dev/null +++ b/.github/workflows/style_check.yml @@ -0,0 +1,69 @@ +# Copyright 2023 Xiaomi Corp. (Fangjun Kuang +# Wei Kang) +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: style_check + +on: + push: + branches: + - master + pull_request: + branches: + - master + +concurrency: + group: style_check-${{ github.ref }} + cancel-in-progress: true + +jobs: + style_check: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Python dependencies + run: | + python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 + # Click issue fixed in https://github.com/psf/black/pull/2966 + + - name: Run flake8 + shell: bash + working-directory: ${{github.workspace}} + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \ + --statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503 + + - name: Run black + shell: bash + working-directory: ${{github.workspace}} + run: | + black --check --diff . From 6a4b834fad1d88d2a0556068bb238dc0a8eceb31 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Thu, 24 Aug 2023 17:23:04 +0800 Subject: [PATCH 2/2] Fix actions (#31) * Fix actions * change to py39 * Fix black * More fixes * Fixes * Fix torch version * Fix cudnn --- .flake8 | 10 ++++++++ .github/scripts/install_cuda.sh | 0 .github/scripts/install_cudnn.sh | 0 .github/scripts/install_torch.sh | 5 ++-- .github/workflows/run_tests_cpu.yml | 10 ++++---- .github/workflows/run_tests_cuda.yml | 9 ++++--- .github/workflows/style_check.yml | 2 +- CMakeLists.txt | 1 + fast_rnnt/python/csrc/CMakeLists.txt | 8 +++++-- fast_rnnt/python/fast_rnnt/__init__.py | 2 -- .../python/fast_rnnt/mutual_information.py | 18 ++++++++------ fast_rnnt/python/fast_rnnt/rnnt_loss.py | 24 ++++++++++++------- setup.py | 8 +++---- 13 files changed, 60 insertions(+), 37 deletions(-) create mode 100644 .flake8 mode change 100644 => 100755 .github/scripts/install_cuda.sh mode change 100644 => 100755 .github/scripts/install_cudnn.sh mode change 100644 => 100755 .github/scripts/install_torch.sh diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..f8611ac --- /dev/null +++ b/.flake8 @@ -0,0 +1,10 @@ +[flake8] +show-source=true +statistics=true +max-line-length=80 + +exclude = + .git, + .github, + setup.py, + build, diff --git a/.github/scripts/install_cuda.sh b/.github/scripts/install_cuda.sh old mode 100644 new mode 100755 diff --git a/.github/scripts/install_cudnn.sh b/.github/scripts/install_cudnn.sh old mode 100644 new mode 100755 diff --git a/.github/scripts/install_torch.sh b/.github/scripts/install_torch.sh old mode 100644 new mode 100755 index 7ba7485..5320c00 --- a/.github/scripts/install_torch.sh +++ b/.github/scripts/install_torch.sh @@ -14,8 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -torch=$TORCH_VERSION -cuda=$CUDA_VERSION +echo "torch version: $torch" +echo "cuda version: $cuda" + case ${torch} in 1.5.*) case ${cuda} in diff --git a/.github/workflows/run_tests_cpu.yml b/.github/workflows/run_tests_cpu.yml index 51baada..1c1035e 100644 --- a/.github/workflows/run_tests_cpu.yml +++ b/.github/workflows/run_tests_cpu.yml @@ -49,9 +49,9 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest] - torch: ["1.13.1"] - torchaudio: ["0.13.1"] - python-version: ["3.11"] + torch: ["1.12.1"] + torchaudio: ["0.12.1"] + python-version: ["3.9"] build_type: ["Release", "Debug"] steps: @@ -81,7 +81,7 @@ jobs: run: | python3 -m pip install -qq --upgrade pip python3 -m pip install -qq torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html - python3 -m pip install -qq torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html + python3 -m pip install -qq torchaudio==${{ matrix.torchaudio }} -f https://download.pytorch.org/whl/cpu/torch_stable.html python3 -c "import torch; print('torch version:', torch.__version__)" python3 -m torch.utils.collect_env @@ -92,7 +92,7 @@ jobs: run: | python3 -m pip install -qq --upgrade pip python3 -m pip install -qq torch==${{ matrix.torch }} - python3 -m pip install -qq torch==${{ matrix.torchaudio }} + python3 -m pip install -qq torchaudio==${{ matrix.torchaudio }} python3 -c "import torch; print('torch version:', torch.__version__)" python3 -m torch.utils.collect_env diff --git a/.github/workflows/run_tests_cuda.yml b/.github/workflows/run_tests_cuda.yml index 1e85115..8781ed6 100644 --- a/.github/workflows/run_tests_cuda.yml +++ b/.github/workflows/run_tests_cuda.yml @@ -47,10 +47,9 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - cuda: ["11.7"] - torch: ["1.13.1"] - torchaudio: ["1.13.0"] - python-version: ["3.11"] + cuda: ["11.6"] + torch: ["1.12.1"] + python-version: ["3.9"] build_type: ["Release", "Debug"] steps: @@ -103,7 +102,7 @@ jobs: env: cuda: ${{ matrix.cuda }} run: | - ./scripts/github_actions/install_cudnn.sh + ./.github/scripts/install_cudnn.sh - name: Configure CMake shell: bash diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 270ec85..24db8a4 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -66,4 +66,4 @@ jobs: shell: bash working-directory: ${{github.workspace}} run: | - black --check --diff . + black -l 80 --check --diff . diff --git a/CMakeLists.txt b/CMakeLists.txt index 6fb3496..becfe71 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,7 @@ endif() cmake_minimum_required(VERSION 3.8 FATAL_ERROR) +set(CMAKE_DISABLE_FIND_PACKAGE_MKL TRUE) set(languages CXX) set(_FT_WITH_CUDA ON) diff --git a/fast_rnnt/python/csrc/CMakeLists.txt b/fast_rnnt/python/csrc/CMakeLists.txt index 7bb8d08..2189d9e 100644 --- a/fast_rnnt/python/csrc/CMakeLists.txt +++ b/fast_rnnt/python/csrc/CMakeLists.txt @@ -16,11 +16,15 @@ endif() pybind11_add_module(_fast_rnnt ${fast_rnnt_srcs}) target_link_libraries(_fast_rnnt PRIVATE mutual_information_core) -if(UNIX AND NOT APPLE) +if(APPLE) + target_link_libraries(_fast_rnnt + PRIVATE + ${TORCH_DIR}/lib/libtorch_python.dylib + ) +elseif(UNIX) target_link_libraries(_fast_rnnt PRIVATE ${PYTHON_LIBRARY} ${TORCH_DIR}/lib/libtorch_python.so ) endif() - diff --git a/fast_rnnt/python/fast_rnnt/__init__.py b/fast_rnnt/python/fast_rnnt/__init__.py index 2883c47..2fccfd6 100644 --- a/fast_rnnt/python/fast_rnnt/__init__.py +++ b/fast_rnnt/python/fast_rnnt/__init__.py @@ -13,5 +13,3 @@ from .rnnt_loss import rnnt_loss_pruned from .rnnt_loss import rnnt_loss_simple from .rnnt_loss import rnnt_loss_smoothed - - diff --git a/fast_rnnt/python/fast_rnnt/mutual_information.py b/fast_rnnt/python/fast_rnnt/mutual_information.py index 2d8029d..6d704fe 100644 --- a/fast_rnnt/python/fast_rnnt/mutual_information.py +++ b/fast_rnnt/python/fast_rnnt/mutual_information.py @@ -160,7 +160,8 @@ def forward( if return_grad or px.requires_grad or py.requires_grad: ans_grad = torch.ones(B, device=px.device, dtype=px.dtype) (px_grad, py_grad) = _fast_rnnt.mutual_information_backward( - px, py, boundary, p, ans_grad) + px, py, boundary, p, ans_grad + ) ctx.save_for_backward(px_grad, py_grad) assert len(pxy_grads) == 2 pxy_grads[0] = px_grad @@ -290,8 +291,9 @@ def mutual_information_recursion( px, py = px.contiguous(), py.contiguous() pxy_grads = [None, None] - scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads, - boundary, return_grad) + scores = MutualInformationRecursionFunction.apply( + px, py, pxy_grads, boundary, return_grad + ) px_grad, py_grad = pxy_grads return (scores, (px_grad, py_grad)) if return_grad else scores @@ -388,16 +390,18 @@ def joint_mutual_information_recursion( p = torch.empty(B, S + 1, T + 1, device=px_tot.device, dtype=px_tot.dtype) # note, tot_probs is without grad. - tot_probs = _fast_rnnt.mutual_information_forward(px_tot, py_tot, boundary, p) + tot_probs = _fast_rnnt.mutual_information_forward( + px_tot, py_tot, boundary, p + ) # this is a kind of "fake gradient" that we use, in effect to compute # occupation probabilities. The backprop will work regardless of the # actual derivative w.r.t. the total probs. ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype) - (px_grad, - py_grad) = _fast_rnnt.mutual_information_backward(px_tot, py_tot, boundary, p, - ans_grad) + (px_grad, py_grad) = _fast_rnnt.mutual_information_backward( + px_tot, py_tot, boundary, p, ans_grad + ) px_grad = px_grad.reshape(1, B, -1) py_grad = py_grad.reshape(1, B, -1) diff --git a/fast_rnnt/python/fast_rnnt/rnnt_loss.py b/fast_rnnt/python/fast_rnnt/rnnt_loss.py index 1667e87..622aa46 100644 --- a/fast_rnnt/python/fast_rnnt/rnnt_loss.py +++ b/fast_rnnt/python/fast_rnnt/rnnt_loss.py @@ -170,7 +170,7 @@ def get_rnnt_logprobs( am.transpose(1, 2), # (B, C, T) dim=1, index=symbols.unsqueeze(2).expand(B, S, T), - ) # (B, S, T) + ) # (B, S, T) if rnnt_type == "regular": px_am = torch.cat( @@ -291,7 +291,9 @@ def rnnt_loss_simple( T = T0 if rnnt_type != "regular" else T0 - 1 if boundary is None: offset = torch.tensor( - (T - 1) / 2, dtype=px.dtype, device=px.device, + (T - 1) / 2, + dtype=px.dtype, + device=px.device, ).expand(B, 1, 1) else: offset = (boundary[:, 3] - 1) / 2 @@ -495,7 +497,9 @@ def rnnt_loss( T = T0 if rnnt_type != "regular" else T0 - 1 if boundary is None: offset = torch.tensor( - (T - 1) / 2, dtype=px.dtype, device=px.device, + (T - 1) / 2, + dtype=px.dtype, + device=px.device, ).expand(B, 1, 1) else: offset = (boundary[:, 3] - 1) / 2 @@ -770,9 +774,7 @@ def do_rnnt_pruning( lm_pruning = torch.gather( lm, dim=1, - index=ranges.reshape(B, T * s_range, 1).expand( - (B, T * s_range, C) - ), + index=ranges.reshape(B, T * s_range, 1).expand((B, T * s_range, C)), ).reshape(B, T, s_range, C) return am_pruning, lm_pruning @@ -1057,7 +1059,9 @@ def rnnt_loss_pruned( T = T0 if rnnt_type != "regular" else T0 - 1 if boundary is None: offset = torch.tensor( - (T - 1) / 2, dtype=px.dtype, device=px.device, + (T - 1) / 2, + dtype=px.dtype, + device=px.device, ).expand(B, 1, 1) else: offset = (boundary[:, 3] - 1) / 2 @@ -1248,7 +1252,7 @@ def get_rnnt_logprobs_smoothed( am.transpose(1, 2), # (B, C, T) dim=1, index=symbols.unsqueeze(2).expand(B, S, T), - ) # (B, S, T) + ) # (B, S, T) if rnnt_type == "regular": px_am = torch.cat( @@ -1413,7 +1417,9 @@ def rnnt_loss_smoothed( T = T0 if rnnt_type != "regular" else T0 - 1 if boundary is None: offset = torch.tensor( - (T - 1) / 2, dtype=px.dtype, device=px.device, + (T - 1) / 2, + dtype=px.dtype, + device=px.device, ).expand(B, 1, 1) else: offset = (boundary[:, 3] - 1) / 2 diff --git a/setup.py b/setup.py index 1d1c775..bed21cb 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ def build_extension(self, ext: setuptools.extension.Extension): cmake_args = "-DCMAKE_BUILD_TYPE=Release -DFT_BUILD_TESTS=OFF" if make_args == "" and system_make_args == "": - make_args = ' -j ' + make_args = " -j " if "PYTHON_EXECUTABLE" not in cmake_args: print(f"Setting PYTHON_EXECUTABLE to {sys.executable}") @@ -89,17 +89,17 @@ def get_package_version(): latest_version = latest_version.strip('"') return latest_version + def get_requirements(): with open("requirements.txt", encoding="utf8") as f: requirements = f.read().splitlines() return requirements + package_name = "fast_rnnt" -with open( - "fast_rnnt/python/fast_rnnt/__init__.py", "a" -) as f: +with open("fast_rnnt/python/fast_rnnt/__init__.py", "a") as f: f.write(f"__version__ = '{get_package_version()}'\n") setuptools.setup(